diff --git a/pyproject.toml b/pyproject.toml index abe9199092..025807fedb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,12 +167,15 @@ asyncio_mode = "auto" [tool.mypy] plugins = [] ignore_missing_imports = true -strict = false +strict = true follow_imports = "silent" -strict_optional = false disable_error_code = ["empty-body"] exclude = ["doc/code/", "pyrit/auxiliary_attacks/"] +[[tool.mypy.overrides]] +module = "pyrit.prompt_target.hugging_face.*" +disallow_untyped_calls = false + [tool.uv] constraint-dependencies = [ "aiohttp>=3.13.4", diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 3c830050b6..a403d1aa37 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -62,9 +62,8 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats raise TypeError(f"Expected AttackResult, got {type(attack).__name__}: {attack!r}") outcome = attack.outcome - attack_type = ( - attack.get_attack_strategy_identifier().class_name if attack.get_attack_strategy_identifier() else "unknown" - ) + _strategy_id = attack.get_attack_strategy_identifier() + attack_type = _strategy_id.class_name if _strategy_id is not None else "unknown" if outcome == AttackOutcome.SUCCESS: overall_counts["successes"] += 1 diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 4149749e45..2cf54eb634 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -296,7 +296,7 @@ def get_access_token_from_interactive_login(scope: str) -> str: """ try: token_provider = get_bearer_token_provider(InteractiveBrowserCredential(), scope) - return token_provider() + return str(token_provider()) except Exception as e: logger.error(f"Failed to obtain token for '{scope}': {e}") raise diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index ea85979fb6..d0ccff4058 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -353,6 +353,9 @@ async def _run_playwright_browser_automation(self) -> Optional[str]: Returns: Optional[str]: The bearer token if successfully retrieved, None otherwise. + + Raises: + ValueError: If the username is not set. """ from playwright.async_api import async_playwright @@ -415,11 +418,15 @@ async def response_handler(response: Any) -> None: logger.info("Waiting for email input...") await page.wait_for_selector("#i0116", timeout=self._elements_timeout) + if self._username is None: + raise ValueError("Username is not set") await page.fill("#i0116", self._username) await page.click("#idSIButton9") logger.info("Waiting for password input...") await page.wait_for_selector("#i0118", timeout=self._elements_timeout) + if self._password is None: + raise ValueError("Password is not set") await page.fill("#i0118", self._password) await page.click("#idSIButton9") @@ -450,7 +457,7 @@ async def response_handler(response: Any) -> None: else: logger.error(f"Failed to retrieve bearer token within {self._token_capture_timeout} seconds.") - return bearer_token # type: ignore[no-any-return] + return bearer_token except Exception as e: logger.error("Failed to retrieve access token using Playwright.") diff --git a/pyrit/backend/middleware/auth.py b/pyrit/backend/middleware/auth.py index 416f4becb5..db7de281ea 100644 --- a/pyrit/backend/middleware/auth.py +++ b/pyrit/backend/middleware/auth.py @@ -61,6 +61,7 @@ def __init__(self, app: ASGIApp) -> None: self._allowed_group_ids: set[str] = {g.strip() for g in groups_raw.split(",") if g.strip()} self._enabled = bool(self._tenant_id and self._client_id) + self._jwks_client: PyJWKClient | None if self._enabled: jwks_url = f"https://login.microsoftonline.com/{self._tenant_id}/discovery/v2.0/keys" self._jwks_client = PyJWKClient(jwks_url, cache_keys=True) @@ -251,6 +252,8 @@ def _validate_token(self, token: str) -> tuple[Optional[AuthenticatedUser], dict Tuple of (AuthenticatedUser, claims) if valid, (None, {}) if validation fails. """ try: + if self._jwks_client is None: + raise RuntimeError("JWKS client not initialized") signing_key = self._jwks_client.get_signing_key_from_jwt(token) claims = jwt.decode( token, diff --git a/pyrit/backend/routes/media.py b/pyrit/backend/routes/media.py index 901fe0520a..6eafb6b5a3 100644 --- a/pyrit/backend/routes/media.py +++ b/pyrit/backend/routes/media.py @@ -123,6 +123,8 @@ async def serve_media_async( """ try: memory = CentralMemory.get_memory_instance() + if not memory.results_path: + raise HTTPException(status_code=500, detail="Memory results_path is not configured.") allowed_root = os.path.realpath(memory.results_path) except Exception as exc: raise HTTPException(status_code=500, detail="Memory not initialized; cannot determine results path.") from exc diff --git a/pyrit/backend/routes/version.py b/pyrit/backend/routes/version.py index f550084eb8..b59d176158 100644 --- a/pyrit/backend/routes/version.py +++ b/pyrit/backend/routes/version.py @@ -67,7 +67,7 @@ async def get_version_async(request: Request) -> VersionResponse: memory = CentralMemory.get_memory_instance() db_type = type(memory).__name__ db_name = None - if memory.engine.url.database: + if memory.engine is not None and memory.engine.url.database: db_name = memory.engine.url.database.split("?")[0] database_info = f"{db_type} ({db_name})" if db_name else f"{db_type} (None)" except Exception as e: diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index a76d286b27..0c8d4719eb 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -568,11 +568,11 @@ def _render_line_with_segments( result: list[str] = [] current_role: Optional[ColorRole] = None for pos, ch in enumerate(line): - role = char_roles[pos] - if role != current_role: - color = _get_color(role, theme) if role else reset + char_role = char_roles[pos] + if char_role != current_role: + color = _get_color(char_role, theme) if char_role else reset result.append(color) - current_role = role + current_role = char_role result.append(ch) result.append(reset) return "".join(result) diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index b4bb584410..12f6257dbb 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -482,29 +482,29 @@ def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> di i = 0 while i < len(parts): token = parts[i] - spec = flag_to_spec.get(token) + matched_spec: _ArgSpec | None = flag_to_spec.get(token) - if spec is None: + if matched_spec is None: valid = sorted(flag_to_spec.keys()) raise ValueError(f"Unknown argument: {token}. Valid arguments: {', '.join(valid)}") i += 1 - if spec.multi_value: + if matched_spec.multi_value: values: list[Any] = [] # Collect values until the next flag (whether valid or invalid) while i < len(parts) and not (parts[i].startswith("--") or parts[i] in flag_to_spec): - item = spec.parser(parts[i]) if spec.parser else parts[i] + item = matched_spec.parser(parts[i]) if matched_spec.parser else parts[i] values.append(item) i += 1 if len(values) == 0: - raise ValueError(f"{spec.flags[0]} requires at least one value") - result[spec.result_key] = values + raise ValueError(f"{matched_spec.flags[0]} requires at least one value") + result[matched_spec.result_key] = values else: if i >= len(parts): - raise ValueError(f"{spec.flags[0]} requires a value") + raise ValueError(f"{matched_spec.flags[0]} requires a value") raw = parts[i] - result[spec.result_key] = spec.parser(raw) if spec.parser else raw + result[matched_spec.result_key] = matched_spec.parser(raw) if matched_spec.parser else raw i += 1 return result diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index bc9519052a..aed63cb96f 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -56,7 +56,7 @@ class termcolor: # type: ignore[no-redef] # noqa: N801 """Dummy termcolor fallback for colored printing if termcolor is not installed.""" @staticmethod - def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: ignore[type-arg] + def cprint(text: str, color: str | None = None, attrs: list[Any] | None = None) -> None: """Print text without color.""" print(text) @@ -249,12 +249,14 @@ def scenario_registry(self) -> ScenarioRegistry: Raises: RuntimeError: If initialize_async() has not been called. + ValueError: If the scenario registry is not initialized. """ if not self._initialized: raise RuntimeError( "FrontendCore not initialized. Call 'await context.initialize_async()' before accessing registries." ) - assert self._scenario_registry is not None + if self._scenario_registry is None: + raise ValueError("self._scenario_registry is not initialized") return self._scenario_registry @property @@ -264,12 +266,14 @@ def initializer_registry(self) -> InitializerRegistry: Raises: RuntimeError: If initialize_async() has not been called. + ValueError: If the initializer registry is not initialized. """ if not self._initialized: raise RuntimeError( "FrontendCore not initialized. Call 'await context.initialize_async()' before accessing registries." ) - assert self._initializer_registry is not None + if self._initializer_registry is None: + raise ValueError("self._initializer_registry is not initialized") return self._initializer_registry diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index ae38edcde8..7fd066ad7c 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -19,6 +19,8 @@ from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: + import types + from pyrit.cli import frontend_core from pyrit.models.scenario_result import ScenarioResult @@ -119,7 +121,7 @@ def __init__( new_item="PyRITShell(database=..., log_level=..., ...)", removed_in="0.14.0", ) - self._deprecated_context = context + self._deprecated_context: frontend_core.FrontendCore | None = context else: self._deprecated_context = None @@ -127,8 +129,9 @@ def __init__( self._scenario_history: list[tuple[str, ScenarioResult]] = [] # Set by the background thread after importing frontend_core. - self.context: Optional[frontend_core.FrontendCore] = None - self.default_log_level: Optional[int] = None + self._fc: types.ModuleType | None = None + self.context: frontend_core.FrontendCore | None = None + self.default_log_level: int | None = None # Initialize PyRIT in background thread for faster startup. self._init_thread = threading.Thread(target=self._background_init, daemon=True) @@ -159,12 +162,19 @@ def _raise_init_error(self) -> None: raise self._init_error def _ensure_initialized(self) -> None: - """Wait for initialization to complete if not already done.""" + """ + Wait for initialization to complete if not already done. + + Raises: + RuntimeError: If frontend core initialization failed or is not complete. + """ if not self._init_complete.is_set(): print("Waiting for PyRIT initialization to complete...") sys.stdout.flush() self._init_complete.wait() self._raise_init_error() + if self._fc is None or self.context is None: + raise RuntimeError("Frontend core not initialized") def cmdloop(self, intro: Optional[str] = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" @@ -188,22 +198,36 @@ def cmdloop(self, intro: Optional[str] = None) -> None: super().cmdloop(intro=self.intro) def do_list_scenarios(self, arg: str) -> None: - """List all available scenarios.""" + """ + List all available scenarios. + + Raises: + RuntimeError: If initialization has not completed. + """ if arg.strip(): print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}") return self._ensure_initialized() + if self._fc is None or self.context is None: + raise RuntimeError("Frontend core not initialized") try: asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) except Exception as e: print(f"Error listing scenarios: {e}") def do_list_initializers(self, arg: str) -> None: - """List all available initializers.""" + """ + List all available initializers. + + Raises: + RuntimeError: If initialization has not completed. + """ if arg.strip(): print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}") return self._ensure_initialized() + if self._fc is None or self.context is None: + raise RuntimeError("Frontend core not initialized") try: asyncio.run(self._fc.print_initializers_list_async(context=self.context)) except Exception as e: @@ -225,8 +249,13 @@ def do_list_targets(self, arg: str) -> None: Examples: list-targets --initializers target list-targets --initializers target:tags=default,scorer + + Raises: + RuntimeError: If initialization has not completed. """ self._ensure_initialized() + if self._fc is None or self.context is None: + raise RuntimeError("Frontend core not initialized") try: list_targets_context = self.context if arg.strip(): @@ -289,8 +318,13 @@ def do_run(self, line: str) -> None: --target is required for every run. Initializers can be specified per-run or configured in .pyrit_conf. Database and env-files are configured via the config file. + + Raises: + RuntimeError: If initialization has not completed. """ self._ensure_initialized() + if self._fc is None or self.context is None: + raise RuntimeError("Frontend core not initialized") if not line.strip(): print("Error: Specify a scenario name") diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index 4fd6bb3b16..64d3ea97fb 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -23,7 +23,7 @@ async def convert_local_image_to_data_url(image_path: str) -> str: str: A string containing the MIME type and the base64-encoded data of the image, formatted as a data URL. """ ext = DataTypeSerializer.get_extension(image_path) - if ext.lower() not in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS: + if not ext or ext.lower() not in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS: raise ValueError( f"Unsupported image format: {ext}. Supported formats are: {AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS}" ) diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 7341df8376..6a97af39cc 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -1,52 +1,56 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import io -import logging - -from PIL import Image - -from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece - -logger = logging.getLogger(__name__) - - -async def display_image_response(response_piece: MessagePiece) -> None: - """ - Display response images if running in notebook environment. - - Args: - response_piece (MessagePiece): The response piece to display. - """ - from pyrit.memory import CentralMemory - - memory = CentralMemory.get_memory_instance() - if ( - response_piece.response_error == "none" - and response_piece.converted_value_data_type == "image_path" - and is_in_ipython_session() - ): - image_location = response_piece.converted_value - - try: - image_bytes = await memory.results_storage_io.read_file(image_location) - except Exception as e: - if isinstance(memory.results_storage_io, AzureBlobStorageIO): - try: - # Fallback to reading from disk if the storage IO fails - image_bytes = await DiskStorageIO().read_file(image_location) - except Exception as exc: - logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}") - return - else: - logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}") - return - - image_stream = io.BytesIO(image_bytes) - image = Image.open(image_stream) - - # Jupyter built-in display function only works in notebooks. - display(image) # type: ignore[name-defined] # noqa: F821 - if response_piece.response_error == "blocked": - logger.info("---\nContent blocked, cannot show a response.\n---") +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import logging + +from PIL import Image + +from pyrit.common.notebook_utils import is_in_ipython_session +from pyrit.memory import CentralMemory +from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece + +logger = logging.getLogger(__name__) + + +async def display_image_response(response_piece: MessagePiece) -> None: + """ + Display response images if running in notebook environment. + + Args: + response_piece (MessagePiece): The response piece to display. + + Raises: + RuntimeError: If storage IO is not initialized. + """ + memory = CentralMemory.get_memory_instance() + if ( + response_piece.response_error == "none" + and response_piece.converted_value_data_type == "image_path" + and is_in_ipython_session() + ): + image_location = response_piece.converted_value + + try: + if memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + image_bytes = await memory.results_storage_io.read_file(image_location) + except Exception as e: + if isinstance(memory.results_storage_io, AzureBlobStorageIO): + try: + # Fallback to reading from disk if the storage IO fails + image_bytes = await DiskStorageIO().read_file(image_location) + except Exception as exc: + logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}") + return + else: + logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}") + return + + image_stream = io.BytesIO(image_bytes) + image = Image.open(image_stream) + + # Jupyter built-in display function only works in notebooks. + display(image) # type: ignore[name-defined] # noqa: F821 + if response_piece.response_error == "blocked": + logger.info("---\nContent blocked, cannot show a response.\n---") diff --git a/pyrit/common/download_hf_model.py b/pyrit/common/download_hf_model.py index ad420a6811..10095b526b 100644 --- a/pyrit/common/download_hf_model.py +++ b/pyrit/common/download_hf_model.py @@ -25,7 +25,7 @@ def get_available_files(model_id: str, token: str) -> list[str]: api = HfApi() try: model_info = api.model_info(model_id, token=token) - available_files = [file.rfilename for file in model_info.siblings] + available_files = [file.rfilename for file in (model_info.siblings or [])] # Perform simple validation: raise a ValueError if no files are available if not len(available_files): diff --git a/pyrit/common/net_utility.py b/pyrit/common/net_utility.py index a1a5ebe4d5..eb75f5616e 100644 --- a/pyrit/common/net_utility.py +++ b/pyrit/common/net_utility.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Literal, Optional, overload +from typing import Any, Literal, Optional, cast, overload from urllib.parse import parse_qs, urlparse, urlunparse import httpx @@ -10,18 +10,18 @@ @overload def get_httpx_client( - use_async: Literal[True], debug: bool = False, **httpx_client_kwargs: Optional[Any] + use_async: Literal[True], debug: bool = False, **httpx_client_kwargs: Any ) -> httpx.AsyncClient: ... @overload def get_httpx_client( - use_async: Literal[False] = False, debug: bool = False, **httpx_client_kwargs: Optional[Any] + use_async: Literal[False] = False, debug: bool = False, **httpx_client_kwargs: Any ) -> httpx.Client: ... def get_httpx_client( - use_async: bool = False, debug: bool = False, **httpx_client_kwargs: Optional[Any] + use_async: bool = False, debug: bool = False, **httpx_client_kwargs: Any ) -> httpx.Client | httpx.AsyncClient: """ Get the httpx client for making requests. @@ -32,10 +32,10 @@ def get_httpx_client( client_class = httpx.AsyncClient if use_async else httpx.Client proxy = "http://localhost:8080" if debug else None - proxy = httpx_client_kwargs.pop("proxy", proxy) - verify_certs = httpx_client_kwargs.pop("verify", not debug) + proxy = cast("str | None", httpx_client_kwargs.pop("proxy", proxy)) + verify_certs = cast("bool", httpx_client_kwargs.pop("verify", not debug)) # fun notes; httpx default is 5 seconds, httpclient is 100, urllib in indefinite - timeout = httpx_client_kwargs.pop("timeout", 60.0) + timeout = cast("float", httpx_client_kwargs.pop("timeout", 60.0)) return client_class(proxy=proxy, verify=verify_certs, timeout=timeout, **httpx_client_kwargs) @@ -92,7 +92,7 @@ async def make_request_and_raise_if_error_async( request_body: Optional[dict[str, object]] = None, files: Optional[dict[str, tuple[str, bytes, str]]] = None, headers: Optional[dict[str, str]] = None, - **httpx_client_kwargs: Optional[Any], + **httpx_client_kwargs: Any, ) -> httpx.Response: """ Make a request and raise an exception if it fails. diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index d2065d0638..b9b1b26768 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -221,14 +221,24 @@ async def _fetch_and_save_image_async(self, image_url: str, behavior_id: str) -> Returns: Local path to the saved image. + + Raises: + RuntimeError: If the serializer memory is not properly configured. """ filename = f"harmbench_{behavior_id}.png" serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") # Return existing path if image already exists for this BehaviorID - serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}") + results_path = serializer._memory.results_path + results_storage_io = serializer._memory.results_storage_io + if not results_path or results_storage_io is None: + raise RuntimeError( + "[HarmBench-Multimodal] Serializer memory is not properly configured: " + "results_path and results_storage_io must be set." + ) + serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") try: - if await serializer._memory.results_storage_io.path_exists(serializer.value): + if await results_storage_io.path_exists(serializer.value): return serializer.value except Exception as e: logger.warning(f"[HarmBench-Multimodal] Failed to check if image for {behavior_id} exists in cache: {e}") diff --git a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py index e52a8b4e4f..2f767fe429 100644 --- a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py @@ -324,9 +324,14 @@ async def _fetch_and_save_image_async(self, image_url: str, example_id: str) -> serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") # Return existing path if image already exists - serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}") + results_path = (serializer._memory.results_path if serializer._memory is not None else None) or "" + serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") try: - if await serializer._memory.results_storage_io.path_exists(serializer.value): + if ( + serializer._memory is not None + and serializer._memory.results_storage_io is not None + and await serializer._memory.results_storage_io.path_exists(serializer.value) + ): return serializer.value except Exception as e: logger.warning(f"[VisualLeakBench] Failed to check if image {example_id} exists in cache: {e}") diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 59aad6baea..7ad0a1b470 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -193,12 +193,12 @@ async def _build_prompt_pair_async(self, example: dict[str, str]) -> list[SeedPr Raises: Exception: If the image cannot be fetched. """ - text = example.get("prompt") - image_url = example.get("web_path") + text = example.get("prompt", "") + image_url = example.get("web_path", "") text_grade = example.get("consensus_text_grade", "").lower() image_grade = example.get("image_grade", "").lower() combined_grade = example.get("consensus_combined_grade", "").lower() - combined_category = example.get("combined_category") + combined_category = example.get("combined_category", "") group_id = uuid.uuid4() local_image_path = await self._fetch_and_save_image_async(image_url, str(group_id)) @@ -248,14 +248,21 @@ async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> st Returns: Local path to the saved image. + + Raises: + RuntimeError: If the serializer memory is not properly configured. """ filename = f"ml_vlsu_{group_id}.png" serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") # Return existing path if image already exists - serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}") + results_path = serializer._memory.results_path + results_storage_io = serializer._memory.results_storage_io + if not results_path or results_storage_io is None: + raise RuntimeError("[ML-VLSU] Serializer memory is not properly configured.") + serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") try: - if await serializer._memory.results_storage_io.path_exists(serializer.value): + if await results_storage_io.path_exists(serializer.value): return serializer.value except Exception as e: logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}") diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index f6b51a10b8..5efbb69107 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -8,6 +8,7 @@ import tenacity from openai import AsyncOpenAI +from pyrit.auth import ensure_async_token_provider from pyrit.common import default_values from pyrit.models import ( EmbeddingData, @@ -60,10 +61,10 @@ def __init__( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key ) - # Create async client - type: ignore needed because get_required_value returns str - # but api_key parameter accepts str | Callable[[], str | Awaitable[str]] + # Wrap sync token providers for async compatibility; AsyncOpenAI accepts str or async callable + resolved_api_key = ensure_async_token_provider(api_key) self._async_client = AsyncOpenAI( - api_key=api_key, # type: ignore[arg-type] + api_key=resolved_api_key, base_url=endpoint, ) diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 95635cde3b..53bd34f6f5 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -123,7 +123,8 @@ async def from_seed_group_async( seed_group.validate() # SeedAttackGroup validates in __init__ that objective is set - assert seed_group.objective is not None + if seed_group.objective is None: + raise ValueError("seed_group.objective is not initialized") # Build params dict, only including fields this class accepts params: dict[str, Any] = {} diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index e92bd1cf67..7ea7f927b7 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -166,7 +166,7 @@ class TAPAttackResult(AttackResult): @property def tree_visualization(self) -> Optional[Tree]: """Get the tree visualization from metadata.""" - return cast("Optional[Tree]", self.metadata.get("tree_visualization", None)) + return self.metadata.get("tree_visualization", None) @tree_visualization.setter def tree_visualization(self, value: Tree) -> None: @@ -1354,7 +1354,9 @@ def __init__( else: # Convert AttackScoringConfig to TAPAttackScoringConfig objective_scorer = attack_scoring_config.objective_scorer - if objective_scorer is not None and not isinstance(objective_scorer, FloatScaleThresholdScorer): + if objective_scorer is None: + raise ValueError("objective_scorer is required") + if not isinstance(objective_scorer, FloatScaleThresholdScorer): raise ValueError( "TAP attack requires a FloatScaleThresholdScorer for objective_scorer. " "Please wrap your scorer in FloatScaleThresholdScorer with an appropriate threshold." diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index e50446bb38..5946ce985c 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -487,9 +487,8 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]: markdown_lines.append("|-------|-------|") markdown_lines.append(f"| **Objective** | {result.objective} |") - attack_type = ( - result.get_attack_strategy_identifier().class_name if result.get_attack_strategy_identifier() else "Unknown" - ) + _strategy_id = result.get_attack_strategy_identifier() + attack_type = _strategy_id.class_name if _strategy_id is not None else "Unknown" markdown_lines.append(f"| **Attack Type** | `{attack_type}` |") markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |") diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 208c4040d7..f30dcc5c43 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -356,9 +356,11 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> Raises: RuntimeError: If knowledge graph extraction fails. + ValueError: If the processing model is not initialized. """ # Processing model is guaranteed to exist when this method is called - assert self._processing_model is not None + if self._processing_model is None: + raise ValueError("self._processing_model is not initialized") self._logger.debug("Extracting knowledge graph from evaluation data") diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 2b4f7a7254..1833bb0f9a 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -1015,13 +1015,19 @@ def _create_normalizer_requests(self, prompts: list[str]) -> list[NormalizerRequ Returns: List of normalizer requests. + + Raises: + ValueError: If a seed group contains no message. """ requests: list[NormalizerRequest] = [] for prompt in prompts: seed_group = SeedGroup(seeds=[SeedPrompt(value=prompt, data_type="text")]) + _msg = seed_group.next_message + if _msg is None: + raise ValueError("No message in seed group") request = NormalizerRequest( - message=seed_group.next_message, + message=_msg, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, ) diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 2dc021b497..1cb22a5773 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -356,8 +356,16 @@ async def _execute_processing_async(self, *, context: XPIAContext) -> str: Returns: str: The response from the processing target. + + Raises: + ValueError: If the processing callback is not set. + RuntimeError: If memory is not initialized. """ + if context.processing_callback is None: + raise ValueError("processing_callback is not set") processing_response = await context.processing_callback() + if self._memory is None: + raise RuntimeError("Memory not initialized") self._memory.add_message_to_memory( request=Message( message_pieces=[ @@ -560,7 +568,8 @@ async def _setup_async(self, *, context: XPIAContext) -> None: # Create the processing callback using the test context async def process_async() -> str: # processing_prompt is validated to be non-None in _validate_context - assert context.processing_prompt is not None + if context.processing_prompt is None: + raise RuntimeError("context.processing_prompt is not initialized") response = await self._prompt_normalizer.send_prompt_async( message=context.processing_prompt, target=self._processing_target, diff --git a/pyrit/identifiers/component_identifier.py b/pyrit/identifiers/component_identifier.py index 39363c96d7..6c43fb6cdb 100644 --- a/pyrit/identifiers/component_identifier.py +++ b/pyrit/identifiers/component_identifier.py @@ -182,7 +182,12 @@ def short_hash(self) -> str: Returns: str: First 8 hex characters of the SHA256 hash. + + Raises: + RuntimeError: If the hash was not set by __post_init__. """ + if self.hash is None: + raise RuntimeError("hash should be set by __post_init__") return self.hash[:8] @property diff --git a/pyrit/identifiers/evaluation_identifier.py b/pyrit/identifiers/evaluation_identifier.py index f6d04ce089..448b5352ba 100644 --- a/pyrit/identifiers/evaluation_identifier.py +++ b/pyrit/identifiers/evaluation_identifier.py @@ -144,8 +144,13 @@ def compute_eval_hash( Returns: str: A hex-encoded SHA256 hash suitable for eval registry keying. + + Raises: + RuntimeError: If the identifier's hash is None and child_eval_rules is empty. """ if not child_eval_rules: + if identifier.hash is None: + raise RuntimeError("hash should be set by __post_init__") return identifier.hash eval_dict = _build_eval_dict( diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 207a7f7f98..d93d2bd4a0 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -142,10 +142,15 @@ def _create_auth_token(self) -> None: def _refresh_token_if_needed(self) -> None: """ Refresh the access token if it is close to expiry (within 5 minutes). + + Raises: + RuntimeError: If auth token expiry was not initialized. """ - if datetime.now(timezone.utc) >= datetime.fromtimestamp(self._auth_token_expiry, tz=timezone.utc) - timedelta( - minutes=5 - ): + if self._auth_token_expiry is None: + raise RuntimeError("Auth token expiry not initialized; call _create_auth_token() first") + if datetime.now(timezone.utc) >= datetime.fromtimestamp( + float(self._auth_token_expiry), tz=timezone.utc + ) - timedelta(minutes=5): logger.info("Refreshing Microsoft Entra ID access token...") self._create_auth_token() @@ -201,6 +206,8 @@ def provide_token(_dialect: Any, _conn_rec: Any, cargs: list[Any], cparams: dict cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") # encode the token + if self._auth_token is None: + raise RuntimeError("Azure auth token is not initialized") azure_token = self._auth_token.token azure_token_bytes = azure_token.encode("utf-16-le") packed_azure_token = struct.pack(f" None: Raises: Exception: If there's an issue creating the tables in the database. + RuntimeError: If the engine is not initialized. """ try: # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables + if self.engine is None: + raise RuntimeError("Engine is not initialized") Base.metadata.create_all(self.engine, checkfirst=True) except Exception as e: logger.exception(f"Error during table creation: {e}") @@ -790,8 +800,16 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict raise def reset_database(self) -> None: - """Drop and recreate existing tables.""" + """ + Drop and recreate existing tables. + + Raises: + RuntimeError: If the engine is not initialized. + """ # Drop all existing tables + if self.engine is None: + raise RuntimeError("Engine is not initialized") + Base.metadata.drop_all(self.engine) # Recreate the tables Base.metadata.create_all(self.engine, checkfirst=True) diff --git a/pyrit/memory/central_memory.py b/pyrit/memory/central_memory.py index a933e73107..675d61fe3c 100644 --- a/pyrit/memory/central_memory.py +++ b/pyrit/memory/central_memory.py @@ -14,7 +14,7 @@ class CentralMemory: The provided memory instance will be reused for future calls. """ - _memory_instance: MemoryInterface = None + _memory_instance: MemoryInterface | None = None @classmethod def set_memory_instance(cls, passed_memory: MemoryInterface) -> None: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index a0abed1476..7e7733e97c 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -69,10 +69,10 @@ class MemoryInterface(abc.ABC): such as files, databases, or cloud storage services. """ - memory_embedding: MemoryEmbedding = None - results_storage_io: StorageIO = None - results_path: str = None - engine: Engine = None + memory_embedding: MemoryEmbedding | None = None + results_storage_io: StorageIO | None = None + results_path: str | None = None + engine: Engine | None = None @staticmethod def _uid() -> str: @@ -1099,7 +1099,7 @@ async def _serialize_seed_value(self, prompt: Seed) -> str: audio_bytes = await serializer.read_data() await serializer.save_data(data=audio_bytes) serialized_prompt_value = str(serializer.value) - return serialized_prompt_value + return serialized_prompt_value or "" async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Optional[str] = None) -> None: """ @@ -1136,8 +1136,9 @@ async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Op await prompt.set_sha256_value_async() - existing = self.get_seeds(value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name) - if not existing: + if prompt.value_sha256 and not self.get_seeds( + value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name + ): entries.append(SeedEntry(entry=prompt)) self._insert_entries(entries=entries) @@ -1869,8 +1870,15 @@ def get_scenario_results( raise def print_schema(self) -> None: - """Print the schema of all tables in the database.""" + """ + Print the schema of all tables in the database. + + Raises: + RuntimeError: If the engine is not initialized. + """ metadata = MetaData() + if self.engine is None: + raise RuntimeError("Engine is not initialized") metadata.reflect(bind=self.engine) for table_name in metadata.tables: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index b34c906af6..6400c049c2 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -400,7 +400,7 @@ def __init__(self, *, entry: Score): self.score_type = entry.score_type self.score_category = entry.score_category self.score_rationale = entry.score_rationale - self.score_metadata = entry.score_metadata + self.score_metadata = entry.score_metadata # type: ignore[assignment] # Normalize to ComponentIdentifier (handles dict with deprecation warning) then convert to dict for JSON storage normalized_scorer = ComponentIdentifier.normalize(entry.scorer_class_identifier) # Ensure eval_hash is set before truncation so it survives the DB round-trip @@ -441,7 +441,7 @@ def get_score(self) -> Score: score_category=self.score_category, score_rationale=self.score_rationale, score_metadata=self.score_metadata, - scorer_class_identifier=scorer_identifier, + scorer_class_identifier=scorer_identifier, # type: ignore[arg-type] message_piece_id=self.prompt_request_response_id, timestamp=_ensure_utc(self.timestamp), objective=self.objective, @@ -593,7 +593,7 @@ def __init__(self, *, entry: Seed): self.source = entry.source self.date_added = entry.date_added self.added_by = entry.added_by - self.prompt_metadata = entry.metadata + self.prompt_metadata = entry.metadata # type: ignore[assignment] self.prompt_group_id = entry.prompt_group_id self.seed_type = seed_type @@ -601,11 +601,11 @@ def __init__(self, *, entry: Seed): if isinstance(entry, SeedPrompt): self.parameters = list(entry.parameters) if entry.parameters else None self.sequence = entry.sequence - self.role = entry.role + self.role = entry.role # type: ignore[assignment] else: self.parameters = None self.sequence = None - self.role = None + self.role = None # type: ignore[assignment] def get_seed(self) -> Seed: """ @@ -673,7 +673,7 @@ def get_seed(self) -> Seed: metadata=self.prompt_metadata, parameters=self.parameters, prompt_group_id=self.prompt_group_id, - sequence=self.sequence, + sequence=self.sequence or 0, role=self.role, ) @@ -1037,7 +1037,7 @@ def get_scenario_result(self) -> ScenarioResult: scenario_identifier=scenario_identifier, objective_target_identifier=target_identifier, attack_results=attack_results, - objective_scorer_identifier=scorer_identifier, + objective_scorer_identifier=scorer_identifier, # type: ignore[arg-type] scenario_run_state=self.scenario_run_state, labels=self.labels, number_tries=self.number_tries, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index bd376d67cd..a4039c1b76 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -134,9 +134,12 @@ def _create_tables_if_not_exist(self) -> None: Raises: Exception: If there's an issue creating the tables in the database. + RuntimeError: If the engine is not initialized. """ try: # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables + if self.engine is None: + raise RuntimeError("Engine is not initialized") Base.metadata.create_all(self.engine, checkfirst=True) except Exception as e: logger.exception(f"Error during table creation: {e}") @@ -440,7 +443,13 @@ def get_session(self) -> Session: def reset_database(self) -> None: """ Drop and recreates all tables in the database. + + Raises: + RuntimeError: If the engine is not initialized. """ + if self.engine is None: + raise RuntimeError("Engine is not initialized") + Base.metadata.drop_all(self.engine) Base.metadata.create_all(self.engine) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 0ebfa37946..2fa3bfc0a2 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -164,7 +164,7 @@ async def _convert_audio_to_input_audio(self, audio_path: str) -> dict[str, Any] ValueError: If the audio format is not supported. FileNotFoundError: If the audio file does not exist. """ - ext = DataTypeSerializer.get_extension(audio_path).lower() + ext = (DataTypeSerializer.get_extension(audio_path) or "").lower() if ext not in SUPPORTED_AUDIO_FORMATS: raise ValueError( f"Unsupported audio format: {ext}. Supported formats are: {list(SUPPORTED_AUDIO_FORMATS.keys())}" diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index c2004160fb..86322e2300 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -96,7 +96,7 @@ class DataTypeSerializer(abc.ABC): data_sub_directory: str file_extension: str - _file_path: Union[Path, str] = None + _file_path: Union[Path, str] | None = None @property def _memory(self) -> MemoryInterface: @@ -113,11 +113,14 @@ def _get_storage_io(self) -> StorageIO: Raises: ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set. + RuntimeError: If results_storage_io is not configured but Azure storage URL was detected. """ if self._is_azure_storage_url(self.value): # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact # with an Azure Storage Account, ex., XPIAWorkflow. + if self._memory.results_storage_io is None: + raise RuntimeError("results_storage_io is not configured but Azure storage URL was detected") return self._memory.results_storage_io return DiskStorageIO() @@ -139,12 +142,16 @@ async def save_data(self, data: bytes, output_filename: Optional[str] = None) -> data: bytes: The data to be saved. output_filename (optional, str): filename to store data as. Defaults to UUID if not provided + Raises: + RuntimeError: If storage IO is not initialized. """ file_path = await self.get_data_filename(file_name=output_filename) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") await self._memory.results_storage_io.write_file(file_path, data) self.value = str(file_path) - async def save_b64_image(self, data: str | bytes, output_filename: str = None) -> None: + async def save_b64_image(self, data: str | bytes, output_filename: str | None = None) -> None: """ Save a base64-encoded image to storage. @@ -152,9 +159,13 @@ async def save_b64_image(self, data: str | bytes, output_filename: str = None) - data: string or bytes with base64 data output_filename (optional, str): filename to store image as. Defaults to UUID if not provided + Raises: + RuntimeError: If storage IO is not initialized. """ file_path = await self.get_data_filename(file_name=output_filename) image_bytes = base64.b64decode(data) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") await self._memory.results_storage_io.write_file(file_path, image_bytes) self.value = str(file_path) @@ -176,6 +187,8 @@ async def save_formatted_audio( sample_width (optional, int): sample width in bytes. Defaults to 2 sample_rate (optional, int): sample rate in Hz. Defaults to 16000 + Raises: + RuntimeError: If storage IO is not initialized. """ file_path = await self.get_data_filename(file_name=output_filename) @@ -190,6 +203,8 @@ async def save_formatted_audio( async with aiofiles.open(local_temp_path, "rb") as f: audio_data = await f.read() + if self._memory.results_storage_io is None: + raise RuntimeError("self._memory.results_storage_io is not initialized") await self._memory.results_storage_io.write_file(file_path, audio_data) os.remove(local_temp_path) @@ -253,7 +268,7 @@ async def get_sha256(self) -> str: ValueError: If in-memory data cannot be converted to bytes. """ - input_bytes: bytes = None + input_bytes: bytes | None = None if self.data_on_disk(): storage_io = self._get_storage_io() @@ -297,7 +312,12 @@ async def get_data_filename(self, file_name: Optional[str] = None) -> Union[Path raise RuntimeError("Data sub directory not set") ticks = int(time.time() * 1_000_000) - results_path = self._memory.results_path + if self._memory.results_path: + results_path = str(self._memory.results_path) + else: + from pyrit.common.path import DB_DATA_PATH + + results_path = str(DB_DATA_PATH) file_name = file_name if file_name else str(ticks) if self._is_azure_storage_url(results_path): @@ -305,6 +325,8 @@ async def get_data_filename(self, file_name: Optional[str] = None) -> Union[Path self._file_path = full_data_directory_path + f"/{file_name}.{self.file_extension}" else: full_data_directory_path = results_path + self.data_sub_directory + if self._memory.results_storage_io is None: + raise RuntimeError("self._memory.results_storage_io is not initialized") await self._memory.results_storage_io.create_directory_if_not_exists(Path(full_data_directory_path)) self._file_path = Path(full_data_directory_path, f"{file_name}.{self.file_extension}") diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 91d01032bf..25f887342d 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -263,7 +263,7 @@ def set_piece_not_in_database(self) -> None: This is needed when we're scoring prompts or other things that have not been sent by PyRIT """ - self.id = None + self.id = None # type: ignore[assignment] def to_dict(self) -> dict[str, object]: """ diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index 1ef38ee82f..30b00e1100 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -93,9 +93,12 @@ def objective(self) -> SeedObjective: Returns: The SeedObjective for this attack group. + Raises: + ValueError: If the attack group does not have an objective. """ obj = self._get_objective() - assert obj is not None, "SeedAttackGroup should always have an objective" + if obj is None: + raise ValueError("SeedAttackGroup should always have an objective") return obj def with_technique(self, *, technique: SeedAttackTechniqueGroup) -> SeedAttackGroup: diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 3e55342fb4..36749a382f 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -182,14 +182,14 @@ def __init__( } if effective_type == "simulated_conversation": - self.seeds.append( - SeedSimulatedConversation( - **base_params, - num_turns=p.get("num_turns", 3), - adversarial_chat_system_prompt_path=p.get("adversarial_chat_system_prompt_path"), - simulated_target_system_prompt_path=p.get("simulated_target_system_prompt_path"), - ) - ) + _adv_path = p.get("adversarial_chat_system_prompt_path") + _sim_path = p.get("simulated_target_system_prompt_path") + _sc_kwargs: dict[str, Any] = {**base_params, "num_turns": p.get("num_turns", 3)} + if _adv_path is not None: + _sc_kwargs["adversarial_chat_system_prompt_path"] = str(_adv_path) + if _sim_path is not None: + _sc_kwargs["simulated_target_system_prompt_path"] = str(_sim_path) + self.seeds.append(SeedSimulatedConversation(**_sc_kwargs)) elif effective_type == "objective": # SeedObjective inherits data_type="text" from base Seed property base_params["value"] = p["value"] diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index dca850d417..a2a733403b 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -35,7 +35,7 @@ class SeedPrompt(Seed): # The type of data this prompt represents (e.g., text, image_path, audio_path, video_path) # This field shadows the base class property to allow per-prompt data types - data_type: Optional[PromptDataType] = None + data_type: Optional[PromptDataType] = None # type: ignore[assignment] # Role of the prompt in a conversation (e.g., "user", "assistant") role: Optional[ChatMessageRole] = None @@ -98,13 +98,15 @@ def set_encoding_metadata(self) -> None: if TinyTag.is_supported(self.value): try: tag = TinyTag.get(self.value) + bitrate = int(round(tag.bitrate)) if tag.bitrate is not None else 0 + duration = int(round(tag.duration)) if tag.duration is not None else 0 self.metadata.update( { - "bitrate": int(round(tag.bitrate)), - "samplerate": tag.samplerate, - "bitdepth": tag.bitdepth, - "filesize": tag.filesize, - "duration": int(round(tag.duration)), + "bitrate": bitrate, + "samplerate": tag.samplerate if tag.samplerate is not None else 0, + "bitdepth": tag.bitdepth if tag.bitdepth is not None else 0, + "filesize": tag.filesize if tag.filesize is not None else 0, + "duration": duration, } ) except Exception as ex: diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index e69306c07a..05d4f8a5e8 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -182,15 +182,18 @@ def __init__( self._container_url: str = container_url self._sas_token = sas_token - self._client_async: AsyncContainerClient = None + self._client_async: AsyncContainerClient | None = None - async def _create_container_client_async(self) -> None: + async def _create_container_client_async(self) -> AsyncContainerClient: """ Create an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. + + Returns: + AsyncContainerClient: The initialized container client. """ sas_token = self._sas_token if not self._sas_token: @@ -201,6 +204,7 @@ async def _create_container_client_async(self) -> None: container_url=self._container_url, credential=sas_token, ) + return self._client_async async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None: """ @@ -211,11 +215,15 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st data (bytes): Byte representation of content to upload to container. content_type (str): Content type to upload. + Raises: + RuntimeError: If the Azure container client is not initialized. """ content_settings = ContentSettings(content_type=f"{content_type}") # type: ignore[no-untyped-call, unused-ignore] logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) try: + if self._client_async is None: + raise RuntimeError("Azure container client not initialized") await self._client_async.upload_blob( name=file_name, data=data, @@ -297,10 +305,6 @@ async def read_file(self, path: Union[Path, str]) -> bytes: Returns: bytes: The content of the file (blob) as bytes. - Raises: - Exception: If there is an error in reading the blob file, an exception will be logged - and re-raised. - Example: file_content = await read_file("https://account.blob.core.windows.net/container/dir2/1726627689003831.png") @@ -309,7 +313,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: """ if not self._client_async: - await self._create_container_client_async() + self._client_async = await self._create_container_client_async() blob_name = self._resolve_blob_name(path) @@ -318,7 +322,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # Download the blob blob_stream = await blob_client.download_blob() - return await blob_stream.readall() + return bytes(await blob_stream.readall()) except Exception as exc: logger.exception(f"Failed to read file at {blob_name}: {exc}") @@ -337,10 +341,9 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: Args: path (Union[Path, str]): Full blob URL or relative blob path. data (bytes): The data to write. - """ if not self._client_async: - await self._create_container_client_async() + self._client_async = await self._create_container_client_async() blob_name = self._resolve_blob_name(path) try: await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) @@ -360,10 +363,9 @@ async def path_exists(self, path: Union[Path, str]) -> bool: Returns: bool: True when the path exists. - """ if not self._client_async: - await self._create_container_client_async() + self._client_async = await self._create_container_client_async() try: blob_name = self._resolve_blob_name(path) blob_client = self._client_async.get_blob_client(blob=blob_name) @@ -384,15 +386,14 @@ async def is_file(self, path: Union[Path, str]) -> bool: Returns: bool: True when the blob exists and has non-zero content size. - """ if not self._client_async: - await self._create_container_client_async() + self._client_async = await self._create_container_client_async() try: blob_name = self._resolve_blob_name(path) blob_client = self._client_async.get_blob_client(blob=blob_name) blob_properties = await blob_client.get_blob_properties() - return blob_properties.size > 0 + return bool(blob_properties.size > 0) except ResourceNotFoundError: return False finally: diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index 8cbf4d8671..b5fb4db89f 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -168,7 +168,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text updated_img = self._add_text_to_image(text=prompt) image_bytes = BytesIO() - mime_type = img_serializer.get_mime_type(self._img_to_add) + mime_type = img_serializer.get_mime_type(self._img_to_add) or "image/png" image_type = mime_type.split("/")[-1] updated_img.save(image_bytes, format=image_type) image_str = base64.b64encode(image_bytes.getvalue()) diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 5f1d2971c6..7b4b109b30 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -142,8 +142,10 @@ async def _add_image_to_video(self, image_path: str, output_path: str) -> str: input_image_bytes = await input_image_data.read_data() image_np_arr = np.frombuffer(input_image_bytes, np.uint8) - overlay = cv2.imdecode(image_np_arr, cv2.IMREAD_UNCHANGED) - overlay = cv2.resize(overlay, self._img_resize_size) + decoded = cv2.imdecode(image_np_arr, cv2.IMREAD_UNCHANGED) + if decoded is None: + raise ValueError("Failed to decode overlay image") + overlay = cv2.resize(decoded, self._img_resize_size) # Get overlay image dimensions image_height, image_width, _ = overlay.shape diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 91fd265e57..ea3236b403 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -165,7 +165,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "imag updated_img = self._add_text_to_image(image=original_img) image_bytes = BytesIO() - mime_type = img_serializer.get_mime_type(prompt) + mime_type = img_serializer.get_mime_type(prompt) or "image/png" image_type = mime_type.split("/")[-1] updated_img.save(image_bytes, format=image_type) image_str = base64.b64encode(image_bytes.getvalue()).decode("utf-8") diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index 7c5fdad176..37ca3f4ec1 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -181,4 +181,4 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text except Exception as e: logger.error("Failed to convert prompt to audio: %s", str(e)) raise - return ConverterResult(output_text=audio_serializer_file, output_type="audio_path") + return ConverterResult(output_text=audio_serializer_file or "", output_type="audio_path") diff --git a/pyrit/prompt_converter/codechameleon_converter.py b/pyrit/prompt_converter/codechameleon_converter.py index 99454946df..43bb1e4ab8 100644 --- a/pyrit/prompt_converter/codechameleon_converter.py +++ b/pyrit/prompt_converter/codechameleon_converter.py @@ -132,7 +132,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text if not self.input_supported(input_type): raise ValueError("Input type not supported") - encoded_prompt = str(self.encrypt_function(prompt)) if self.encrypt_function else prompt + encoded_prompt = str(self.encrypt_function(prompt)) if self.encrypt_function is not None else prompt seed_prompt = SeedPrompt.from_yaml_file( pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "codechameleon_converter.yaml" diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index a9672e3718..46f427caef 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -28,7 +28,7 @@ def __init__( *, converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, - denylist: list[str] = None, + denylist: list[str] | None = None, ): """ Initialize the converter with a target, an optional system prompt template, and a denylist. diff --git a/pyrit/prompt_converter/template_segment_converter.py b/pyrit/prompt_converter/template_segment_converter.py index 8520436471..07ab83e164 100644 --- a/pyrit/prompt_converter/template_segment_converter.py +++ b/pyrit/prompt_converter/template_segment_converter.py @@ -51,18 +51,18 @@ def __init__( ) ) - self._number_parameters = len(self.prompt_template.parameters) + self._number_parameters = len(self.prompt_template.parameters or []) if self._number_parameters < 2: raise ValueError( - f"Template must have at least two parameters, but found {len(self.prompt_template.parameters)}. " + f"Template must have at least two parameters, but found {len(self.prompt_template.parameters or [])}. " f"Template parameters: {self.prompt_template.parameters}" ) # Validate all parameters exist in the template value by attempting to render with empty values try: # Create a dict with empty values for all parameters - empty_values = dict.fromkeys(self.prompt_template.parameters, "") + empty_values = dict.fromkeys(self.prompt_template.parameters or [], "") # This will raise ValueError if any parameter is missing self.prompt_template.render_template_value(**empty_values) except ValueError as e: @@ -107,7 +107,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text segments = self._split_prompt_into_segments(prompt) filled_template = self.prompt_template.render_template_value( - **dict(zip(self.prompt_template.parameters, segments, strict=False)) + **dict(zip(self.prompt_template.parameters or [], segments, strict=False)) ) return ConverterResult(output_text=filled_template, output_type="text") diff --git a/pyrit/prompt_normalizer/normalizer_request.py b/pyrit/prompt_normalizer/normalizer_request.py index 30869a09b2..1cfaf97f37 100644 --- a/pyrit/prompt_normalizer/normalizer_request.py +++ b/pyrit/prompt_normalizer/normalizer_request.py @@ -25,8 +25,8 @@ def __init__( self, *, message: Message, - request_converter_configurations: list[PromptConverterConfiguration] = None, - response_converter_configurations: list[PromptConverterConfiguration] = None, + request_converter_configurations: list[PromptConverterConfiguration] | None = None, + response_converter_configurations: list[PromptConverterConfiguration] | None = None, conversation_id: Optional[str] = None, ): """ diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index b730a58669..7407cd8498 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -32,7 +32,19 @@ class PromptNormalizer: Handles normalization and processing of prompts before they are sent to targets. """ - _memory: MemoryInterface = None + _memory: MemoryInterface | None = None + + @property + def memory(self) -> MemoryInterface: + """ + Get the memory instance. + + Raises: + RuntimeError: If memory is not initialized. + """ + if self._memory is None: + raise RuntimeError("Memory is not initialized") + return self._memory def __init__(self, start_token: str = "⟪", end_token: str = "⟫") -> None: """ @@ -74,6 +86,7 @@ async def send_prompt_async( Raises: Exception: If an error occurs during the request processing. ValueError: If the message pieces are not part of the same sequence. + EmptyResponseException: If the target returns no valid responses. Returns: Message: The response received from the target. @@ -105,10 +118,10 @@ async def send_prompt_async( try: responses = await target.send_prompt_async(message=request) - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) responses = [ construct_response_from_request( @@ -121,7 +134,7 @@ async def send_prompt_async( except Exception as ex: # Ensure request to memory before processing exception - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) error_response = construct_response_from_request( request=request.message_pieces[0], @@ -131,13 +144,19 @@ async def send_prompt_async( ) await self._calc_hash(request=error_response) - self._memory.add_message_to_memory(request=error_response) + self.memory.add_message_to_memory(request=error_response) cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex # handling empty responses message list and None responses if not responses or not any(responses): - return None + # An empty list is valid for write-only targets (e.g., TextTarget) + # that don't produce responses. Return the request as-is. + if responses is not None and len(responses) == 0: + await self._calc_hash(request=request) + self.memory.add_message_to_memory(request=request) + return request + raise EmptyResponseException(message="Target returned no valid responses") # Process all response messages (targets return list[Message]) # Only apply response converters to the last message (final response) @@ -147,7 +166,7 @@ async def send_prompt_async( if is_last: await self.convert_values(converter_configurations=response_converter_configurations, message=resp) await self._calc_hash(request=resp) - self._memory.add_message_to_memory(request=resp) + self.memory.add_message_to_memory(request=resp) # Return the last response for backward compatibility return responses[-1] @@ -312,6 +331,6 @@ async def add_prepended_conversation_to_memory( # and if not, this won't hurt anything piece.id = uuid4() - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) return prepended_conversation diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index dcabad099c..e75f27d6b9 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -151,6 +151,9 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st file_name (str): File name to assign to uploaded blob. data (bytes): Byte representation of content to upload to container. content_type (str): Content type to upload. + + Raises: + RuntimeError: If blob storage client is not initialized. """ content_settings = ContentSettings(content_type=f"{content_type}") # type: ignore[no-untyped-call, unused-ignore] logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) @@ -163,6 +166,8 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st # If not, the file will be put in the root of the container. blob_path = f"{blob_prefix}/{file_name}" if blob_prefix else file_name try: + if self._client_async is None: + raise RuntimeError("Blob storage client not initialized") blob_client = self._client_async.get_blob_client(blob=blob_path) if await blob_client.exists(): logger.info(msg=f"Blob {blob_path} already exists. Deleting it before uploading a new version.") diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index ae02026004..eea7396d94 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -42,9 +42,9 @@ class HuggingFaceChatTarget(PromptChatTarget): ) # Class-level cache for model and tokenizer - _cached_model = None - _cached_tokenizer = None - _cached_model_id = None + _cached_model: Any = None + _cached_tokenizer: Any = None + _cached_model_id: str | None = None # Class-level flag to enable or disable cache _cache_enabled = True @@ -186,9 +186,7 @@ def _load_from_path(self, path: str, **kwargs: Any) -> None: **kwargs: Additional keyword arguments to pass to the model loader. """ logger.info(f"Loading model and tokenizer from path: {path}...") - self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call, unused-ignore] - path, trust_remote_code=self.trust_remote_code - ) + self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=self.trust_remote_code) self.model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=self.trust_remote_code, **kwargs) def is_model_id_valid(self) -> bool: @@ -200,7 +198,7 @@ def is_model_id_valid(self) -> bool: """ try: # Attempt to load the configuration of the model - PretrainedConfig.from_pretrained(self.model_id) + PretrainedConfig.from_pretrained(self.model_id or "") return True except Exception as e: logger.error(f"Invalid HuggingFace model ID {self.model_id}: {e}") @@ -248,27 +246,27 @@ async def load_model_and_tokenizer(self) -> None: ".cache", "huggingface", "hub", - f"models--{self.model_id.replace('/', '--')}", + f"models--{(self.model_id or '').replace('/', '--')}", ) if self.necessary_files is None: # Download all files if no specific files are provided logger.info(f"Downloading all files for {self.model_id}...") - await download_specific_files(self.model_id, None, self.huggingface_token, Path(cache_dir)) + await download_specific_files(self.model_id or "", None, self.huggingface_token, Path(cache_dir)) else: # Download only the necessary files logger.info(f"Downloading specific files for {self.model_id}...") await download_specific_files( - self.model_id, self.necessary_files, self.huggingface_token, Path(cache_dir) + self.model_id or "", self.necessary_files, self.huggingface_token, Path(cache_dir) ) # Load the tokenizer and model from the specified directory logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...") - self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call, unused-ignore] - self.model_id, cache_dir=cache_dir, trust_remote_code=self.trust_remote_code + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id or "", cache_dir=cache_dir, trust_remote_code=self.trust_remote_code ) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, + self.model_id or "", cache_dir=cache_dir, trust_remote_code=self.trust_remote_code, **optional_model_kwargs, @@ -363,7 +361,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: response = construct_response_from_request( request=request, response_text_pieces=[assistant_response], - prompt_metadata={"model_id": model_identifier}, + prompt_metadata={"model_id": model_identifier or ""}, ) return [response] diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 514f0a7c27..f6568529fd 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -265,7 +265,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handling - automatically detects ChatCompletion and validates response = await self._handle_openai_request( - api_call=lambda: self._async_client.chat.completions.create(**body), + api_call=lambda: self._client.chat.completions.create(**body), request=message, ) return [response] diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 14a39f29a6..39a11944e2 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -153,7 +153,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handler - automatically detects Completion and validates response = await self._handle_openai_request( - api_call=lambda: self._async_client.completions.create(**request_params), # type: ignore[call-overload] + api_call=lambda: self._client.completions.create(**request_params), # type: ignore[call-overload] request=message, ) return [response] diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 2009c5ef7c..2cce81e71f 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -205,7 +205,7 @@ async def _send_generate_request_async(self, message: Message) -> Message: # Use unified error handler for consistent error handling return await self._handle_openai_request( - api_call=lambda: self._async_client.images.generate(**image_generation_args), + api_call=lambda: self._client.images.generate(**image_generation_args), request=message, ) @@ -255,7 +255,7 @@ async def _send_edit_request_async(self, message: Message) -> Message: image_edit_args["style"] = self.style return await self._handle_openai_request( - api_call=lambda: self._async_client.images.edit(**image_edit_args), + api_call=lambda: self._client.images.edit(**image_edit_args), request=message, ) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 2948988c63..be6a01abc2 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -559,7 +559,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handling - automatically detects Response and validates result = await self._handle_openai_request( - api_call=lambda body=body: self._async_client.responses.create(**body), + api_call=lambda body=body: self._client.responses.create(**body), request=message, ) diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 4788fbaf37..8058a2b7fd 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -63,6 +63,18 @@ class OpenAITarget(PromptTarget): _async_client: Optional[AsyncOpenAI] = None + @property + def _client(self) -> AsyncOpenAI: + """ + Non-None accessor for the async client, used by subclasses. + + Raises: + RuntimeError: If the AsyncOpenAI client is not initialized. + """ + if self._async_client is None: + raise RuntimeError("AsyncOpenAI client is not initialized") + return self._async_client + def __init__( self, *, @@ -419,6 +431,7 @@ async def _handle_openai_request( APITimeoutError: For transient infrastructure errors. APIConnectionError: For transient infrastructure errors. AuthenticationError: For authentication failures. + ValueError: If there are no message pieces in the request. """ try: # Execute the API call @@ -426,6 +439,8 @@ async def _handle_openai_request( # Extract MessagePiece for validation and construction (most targets use single piece) request_piece = request.message_pieces[0] if request.message_pieces else None + if request_piece is None: + raise ValueError("No message pieces in request") # Check for content filter via subclass implementation if self._check_content_filter(response): @@ -452,6 +467,8 @@ def model_dump_json(self) -> str: return error_str request_piece = request.message_pieces[0] if request.message_pieces else None + if request_piece is None: + raise ValueError("No message pieces in request") from e return self._handle_content_filter_response(_ErrorResponse(), request_piece) except BadRequestError as e: # Handle 400 errors - includes input policy filters and some Azure output-filter 400s @@ -470,6 +487,8 @@ def model_dump_json(self) -> str: ) request_piece = request.message_pieces[0] if request.message_pieces else None + if request_piece is None: + raise ValueError("No message pieces in request") from e return handle_bad_request_exception( response_text=str(payload), request=request_piece, @@ -583,7 +602,7 @@ def _set_openai_env_configuration_vars(self) -> None: raise NotImplementedError def _warn_url_with_api_path( - self, endpoint_url: str, api_path: str, provider_examples: dict[str, str] = None + self, endpoint_url: str, api_path: str, provider_examples: dict[str, str] | None = None ) -> None: """ Warn if URL includes API-specific path that should be handled by the SDK. diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index bde1e2df93..71e45ec1de 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -142,10 +142,10 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handler for consistent error handling response = await self._handle_openai_request( - api_call=lambda: self._async_client.audio.speech.create( - model=body_parameters["model"], # type: ignore[arg-type] - voice=body_parameters["voice"], # type: ignore[arg-type] - input=body_parameters["input"], # type: ignore[arg-type] + api_call=lambda: self._client.audio.speech.create( + model=str(body_parameters["model"]), + voice=str(body_parameters["voice"]), + input=str(body_parameters["input"]), response_format=body_parameters.get("response_format"), # type: ignore[arg-type] speed=body_parameters.get("speed"), # type: ignore[arg-type] ), diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index ebbee5eeb5..bd33fdf5df 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -207,6 +207,8 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: self._validate_request(message=message) text_piece = message.get_piece_by_type(data_type="text") + if text_piece is None: + raise ValueError("No text piece found in message") # Validate video_path pieces for remix mode (does not strip them) self._validate_video_remix_pieces(message=message) @@ -265,7 +267,7 @@ async def _send_text_plus_image_to_video_async( logger.info("Text+Image-to-video mode: Using image as first frame") input_file = await self._prepare_image_input_async(image_piece=image_piece) return await self._handle_openai_request( - api_call=lambda: self._async_client.videos.create_and_poll( + api_call=lambda: self._client.videos.create_and_poll( model=self._model_name, prompt=prompt, size=self._size, @@ -287,7 +289,7 @@ async def _send_text_to_video_async(self, *, prompt: str, request: Message) -> M The response Message with the generated video path. """ return await self._handle_openai_request( - api_call=lambda: self._async_client.videos.create_and_poll( + api_call=lambda: self._client.videos.create_and_poll( model=self._model_name, prompt=prompt, size=self._size, @@ -343,11 +345,11 @@ async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any: Returns: The completed Video object from the OpenAI SDK. """ - video = await self._async_client.videos.remix(video_id, prompt=prompt) + video = await self._client.videos.remix(video_id, prompt=prompt) # Poll until completion if not already done if video.status not in ["completed", "failed"]: - video = await self._async_client.videos.poll(video.id) + video = await self._client.videos.poll(video.id) return video @@ -397,7 +399,7 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> logger.info(f"Video was remixed from: {video.remixed_from_video_id}") # Download video content using SDK - video_response = await self._async_client.videos.download_content(video.id) + video_response = await self._client.videos.download_content(video.id) # Extract bytes from HttpxBinaryResponseContent video_content = video_response.content diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 8bb249f350..16f95a8506 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -88,10 +88,15 @@ def __init__( this target instance. Defaults to None. custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. + + Raises: + ValueError: If the endpoint value is not provided. """ endpoint_value = default_values.get_required_value( env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint ) + if endpoint_value is None: + raise ValueError("Endpoint value is required") super().__init__( max_requests_per_minute=max_requests_per_minute, endpoint=endpoint_value, @@ -99,12 +104,15 @@ def __init__( custom_capabilities=custom_capabilities, ) - self._api_version = api_version + self._api_version = api_version or "2024-09-01" # API key is required - either from parameter or environment variable - self._api_key = default_values.get_required_value( + _api_key_value = default_values.get_required_value( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key ) + if _api_key_value is None: + raise ValueError("API key is required") + self._api_key = _api_key_value self._force_entry_field: PromptShieldEntryField = field diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 4c736daf2d..a07fa75e11 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -84,7 +84,7 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: original_value=row["value"], original_value_data_type=row.get("data_type", None), # type: ignore[arg-type] conversation_id=row.get("conversation_id", None), - sequence=int(sequence_str) if sequence_str else None, + sequence=int(sequence_str) if sequence_str else 0, labels=labels, response_error=row.get("response_error", None), # type: ignore[arg-type] prompt_target_identifier=self.get_identifier(), diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 9b56078272..96aa6d15ad 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -600,7 +600,7 @@ def _validate_request(self, *, message: Message) -> None: if piece_type == "image_path": mime_type = DataTypeSerializer.get_mime_type(piece.converted_value) - if not mime_type.startswith("image/"): + if not mime_type or not mime_type.startswith("image/"): raise ValueError( f"Invalid image format for image_path: {piece.converted_value}. " f"Detected MIME type: {mime_type}." diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 4007c58d66..a8043ea9ae 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -76,13 +76,17 @@ def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: boo To discover only scenarios, pass pyrit/setup/initializers/scenarios. lazy_discovery: If True, discovery is deferred until first access. Defaults to False for backwards compatibility. + + Raises: + ValueError: If the discovery path could not be resolved. """ self._discovery_path = discovery_path if self._discovery_path is None: self._discovery_path = Path(PYRIT_PATH) / "setup" / "initializers" # At this point _discovery_path is guaranteed to be a Path - assert self._discovery_path is not None + if self._discovery_path is None: + raise ValueError("self._discovery_path is not initialized") super().__init__(lazy_discovery=lazy_discovery) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 7e91c53bda..ee0faf910a 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -627,7 +627,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point # (verified in run_async before calling this method) - assert self._scenario_result_id is not None + if self._scenario_result_id is None: + raise ValueError("self._scenario_result_id is not initialized") scenario_result_id: str = self._scenario_result_id # Increment number_tries at the start of each run diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py index d22ece85ff..5b87a032b1 100644 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ b/pyrit/scenario/scenarios/airt/content_harms.py @@ -187,7 +187,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index fe53b62bb9..63578b8bce 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -159,7 +159,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: presence_of_malware = SelfAskTrueFalseScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ), true_false_question_path=SCORER_SEED_PROMPT_PATH / "true_false_question" / "malware.yaml", @@ -169,7 +169,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -189,7 +189,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index 9667cca820..68aa01aa84 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -125,7 +125,7 @@ def __init__( scenario_result_id: Optional[str] = None, num_templates: Optional[int] = None, num_attempts: int = 1, - jailbreak_names: list[str] = None, + jailbreak_names: list[str] | None = None, ) -> None: """ Initialize the jailbreak scenario. @@ -201,7 +201,7 @@ def _create_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -298,7 +298,7 @@ async def _get_atomic_attack_from_strategy_async( return AtomicAttack( atomic_attack_name=f"jailbreak_{template_name}", attack_technique=AttackTechnique(attack=attack), - seed_groups=self._seed_groups, + seed_groups=self._seed_groups or [], ) async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: diff --git a/pyrit/scenario/scenarios/airt/leakage.py b/pyrit/scenario/scenarios/airt/leakage.py index 43354690c5..0a35f79eec 100644 --- a/pyrit/scenario/scenarios/airt/leakage.py +++ b/pyrit/scenario/scenarios/airt/leakage.py @@ -187,7 +187,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: presence_of_leakage = SelfAskTrueFalseScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ), true_false_question_path=SCORER_SEED_PROMPT_PATH / "true_false_question" / "leakage.yaml", @@ -200,7 +200,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -220,7 +220,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 2caf2d6f78..fd20991441 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -292,7 +292,7 @@ def _resolve_seed_groups(self) -> ResolvedSeedData: if harm_category_filter: seed_groups = self._filter_by_harm_category( - seed_groups=seed_groups, + seed_groups=seed_groups or [], harm_category=harm_category_filter, ) logger.info( @@ -358,7 +358,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=0.7, ) @@ -398,7 +398,7 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") azure_openai_chat_target = OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) @@ -465,7 +465,7 @@ def _create_single_turn_attacks( AtomicAttack( atomic_attack_name="psychosocial_single_turn", attack_technique=AttackTechnique(attack=prompt_sending), - seed_groups=seed_groups, + seed_groups=seed_groups or [], memory_labels=self._memory_labels, ) ) @@ -479,7 +479,7 @@ def _create_single_turn_attacks( AtomicAttack( atomic_attack_name="psychosocial_role_play", attack_technique=AttackTechnique(attack=role_play), - seed_groups=seed_groups, + seed_groups=seed_groups or [], memory_labels=self._memory_labels, ) ) @@ -516,6 +516,6 @@ def _create_multi_turn_attack( return AtomicAttack( atomic_attack_name="psychosocial_crescendo_turn", attack_technique=AttackTechnique(attack=crescendo), - seed_groups=seed_groups, + seed_groups=seed_groups or [], memory_labels=self._memory_labels, ) diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index e2db719f40..c91c4130e3 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -187,7 +187,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scam_materials = SelfAskTrueFalseScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=0.9, ), @@ -198,7 +198,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -216,7 +216,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -290,7 +290,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: return AtomicAttack( atomic_attack_name=f"scam_{strategy}", attack_technique=AttackTechnique(attack=attack_strategy), - seed_groups=self._seed_groups, + seed_groups=self._seed_groups or [], memory_labels=self._memory_labels, ) diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index 0f2fe297b8..68b7e8b44d 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -312,7 +312,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -475,7 +475,7 @@ def _get_attack( # Create the adversarial config from self._adversarial_target attack_adversarial_config = AttackAdversarialConfig(target=self._adversarial_chat) - kwargs["attack_adversarial_config"] = attack_adversarial_config + kwargs["attack_adversarial_config"] = attack_adversarial_config # type: ignore[assignment] # Add attack-specific kwargs if provided if attack_kwargs: diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 16aa3d75ab..8a13f5ad5a 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -118,6 +118,7 @@ def __init__( Raises: ValueError: If no endpoint is provided. + RuntimeError: If the API key is not a string when validation is performed. """ if harm_categories: self._harm_categories = harm_categories @@ -151,6 +152,8 @@ def __init__( self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key + if not isinstance(self._api_key, str): + raise RuntimeError("Expected string API key") self._azure_cf_client = ContentSafetyClient(self._endpoint, AzureKeyCredential(self._api_key)) else: raise ValueError("Please provide the Azure Content Safety endpoint") @@ -180,7 +183,7 @@ async def evaluate_async( file_mapping: Optional["ScorerEvalDatasetFiles"] = None, *, num_scorer_trials: int = 3, - update_registry_behavior: "RegistryUpdateBehavior" = None, + update_registry_behavior: "RegistryUpdateBehavior | None" = None, max_concurrency: int = 10, ) -> Optional["ScorerMetrics"]: """ diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index af39cf5bec..a117034b3b 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -58,8 +58,12 @@ def get_scorer_metrics(self) -> Optional["HarmScorerMetrics"]: if self.evaluation_file_mapping is None or self.evaluation_file_mapping.harm_category is None: return None + eval_hash = self.get_identifier().eval_hash + if eval_hash is None: + return None + return find_harm_metrics_by_eval_hash( - eval_hash=self.get_identifier().eval_hash, + eval_hash=eval_hash, harm_category=self.evaluation_file_mapping.harm_category, ) diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index c8270a10a9..0c5772f58a 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -77,16 +77,16 @@ def _get_quality_color( """ if higher_is_better: if value >= good_threshold: - return Fore.GREEN # type: ignore[no-any-return] + return str(Fore.GREEN) if value < bad_threshold: - return Fore.RED # type: ignore[no-any-return] - return Fore.CYAN # type: ignore[no-any-return] + return str(Fore.RED) + return str(Fore.CYAN) # Lower is better (e.g., MAE, score time) if value <= good_threshold: - return Fore.GREEN # type: ignore[no-any-return] + return str(Fore.GREEN) if value > bad_threshold: - return Fore.RED # type: ignore[no-any-return] - return Fore.CYAN # type: ignore[no-any-return] + return str(Fore.RED) + return str(Fore.CYAN) def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: """ diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index b18a1802a9..11308edb64 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -268,7 +268,7 @@ async def evaluate_async( file_mapping: Optional[ScorerEvalDatasetFiles] = None, *, num_scorer_trials: int = 3, - update_registry_behavior: RegistryUpdateBehavior = None, + update_registry_behavior: RegistryUpdateBehavior | None = None, max_concurrency: int = 10, ) -> Optional[ScorerMetrics]: """ @@ -355,7 +355,7 @@ async def score_text_async(self, text: str, *, objective: Optional[str] = None) ] ) - request.message_pieces[0].id = None + request.message_pieces[0].id = None # type: ignore[assignment] return await self.score_async(request, objective=objective) async def score_image_async(self, image_path: str, *, objective: Optional[str] = None) -> list[Score]: @@ -379,7 +379,7 @@ async def score_image_async(self, image_path: str, *, objective: Optional[str] = ] ) - request.message_pieces[0].id = None + request.message_pieces[0].id = None # type: ignore[assignment] return await self.score_async(request, objective=objective) async def score_prompts_batch_async( diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index f0541b7e23..be931d0b01 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -295,6 +295,10 @@ def _should_skip_evaluation( try: scorer_hash = self.scorer.get_identifier().eval_hash + if scorer_hash is None: + logger.debug("No eval_hash available for scorer, cannot check existing metrics") + return (False, None) + # Determine if this is a harm or objective evaluation metrics_type = MetricsType.OBJECTIVE if isinstance(self.scorer, TrueFalseScorer) else MetricsType.HARM @@ -504,10 +508,14 @@ def _write_metrics_to_registry( result_file_path (Path): The full path to the result file. """ try: + eval_hash = self.scorer.get_identifier().eval_hash + if eval_hash is None: + logger.warning("Cannot write metrics: no eval_hash available for scorer") + return replace_evaluation_results( file_path=result_file_path, scorer_identifier=self.scorer.get_identifier(), - eval_hash=self.scorer.get_identifier().eval_hash, + eval_hash=eval_hash, metrics=metrics, ) except Exception as e: diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index 652623d8cd..5b3f067bb3 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -119,17 +119,16 @@ def _parse_response_to_boolean_list(self, response: str) -> list[bool]: """ response_json: dict[str, Any] = json.loads(response) - user_detections = [] - document_detections = [] - user_prompt_attack: dict[str, bool] = response_json.get("userPromptAnalysis", False) documents_attack: list[dict[str, Any]] = response_json.get("documentsAnalysis", False) - user_detections = [False] if not user_prompt_attack else [user_prompt_attack.get("attackDetected")] + user_detections: list[bool] = ( + [False] if not user_prompt_attack else [bool(user_prompt_attack.get("attackDetected"))] + ) if not documents_attack: - document_detections = [False] + document_detections: list[bool] = [False] else: - document_detections = [document.get("attackDetected") for document in documents_attack] + document_detections = [bool(document.get("attackDetected")) for document in documents_attack] return user_detections + document_detections diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index da1054274d..d79060fcb4 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -140,6 +140,8 @@ def __init__( if true_false_question_path: true_false_question_path = verify_and_resolve_path(true_false_question_path) true_false_question = yaml.safe_load(true_false_question_path.read_text(encoding="utf-8")) + if true_false_question is None: + raise ValueError("Failed to load true_false_question YAML") for key in ["category", "true_description", "false_description"]: if key not in true_false_question: diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index c66c24d437..45d0dc4cdb 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -113,7 +113,8 @@ async def _score_async( # Ensure the message piece has an ID piece_id = message.message_pieces[0].id - assert piece_id is not None, "Message piece must have an ID" + if piece_id is None: + raise ValueError("Message piece must have an ID") return_score = Score( score_value=str(result.value), diff --git a/pyrit/score/true_false/true_false_scorer.py b/pyrit/score/true_false/true_false_scorer.py index 9074b79170..b0c90c0737 100644 --- a/pyrit/score/true_false/true_false_scorer.py +++ b/pyrit/score/true_false/true_false_scorer.py @@ -94,7 +94,11 @@ def get_scorer_metrics(self) -> Optional["ObjectiveScorerMetrics"]: if not result_file.exists(): return None - return find_objective_metrics_by_eval_hash(eval_hash=self.get_identifier().eval_hash, file_path=result_file) + eval_hash = self.get_identifier().eval_hash + if eval_hash is None: + return None + + return find_objective_metrics_by_eval_hash(eval_hash=eval_hash, file_path=result_file) async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: """ diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index c899797b92..a0e61c52d4 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -110,6 +110,9 @@ async def initialize_async(self) -> None: 2. Composite harm and objective scorers 3. Adversarial target configurations 4. Default values for all attack types + + Raises: + ValueError: If required environment variables are not set. """ # Ensure operator, operation, and email are populated from GLOBAL_MEMORY_LABELS. self._validate_operation_fields() @@ -121,8 +124,10 @@ async def initialize_async(self) -> None: scorer_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2") # Type assertions - safe because validate() already checked these - assert converter_endpoint is not None - assert scorer_endpoint is not None + if converter_endpoint is None: + raise ValueError("converter_endpoint is not initialized") + if scorer_endpoint is None: + raise ValueError("scorer_endpoint is not initialized") # model name can be empty in certain cases (e.g., custom model deployments that don't need model name) # Check for API keys first, fall back to Entra auth if not set @@ -137,7 +142,7 @@ async def initialize_async(self) -> None: # 1. Setup converter target self._setup_converter_target( - endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name + endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name or "" ) # 2. Setup scorers @@ -145,12 +150,12 @@ async def initialize_async(self) -> None: endpoint=scorer_endpoint, api_key=scorer_api_key, content_safety_api_key=content_safety_api_key, - model_name=scorer_model_name, + model_name=scorer_model_name or "", ) # 3. Setup adversarial targets self._setup_adversarial_targets( - endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name + endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name or "" ) def _setup_converter_target(self, *, endpoint: str, api_key: str, model_name: str) -> None: diff --git a/pyrit/show_versions.py b/pyrit/show_versions.py index e19fde71ff..301faebdd7 100644 --- a/pyrit/show_versions.py +++ b/pyrit/show_versions.py @@ -56,7 +56,7 @@ def _get_deps_info() -> dict[str, str | None]: from pyrit import __version__ - deps_info = {"pyrit": __version__} + deps_info: dict[str, str | None] = {"pyrit": __version__} from importlib.metadata import PackageNotFoundError, version @@ -78,5 +78,5 @@ def show_versions() -> None: print(f"{k:>10}: {stat}") print("\nPython dependencies:") - for k, stat in deps_info.items(): - print(f"{k:>13}: {stat}") + for k, stat_or_none in deps_info.items(): + print(f"{k:>13}: {stat_or_none}") diff --git a/tests/unit/datasets/test_harmbench_multimodal_dataset.py b/tests/unit/datasets/test_harmbench_multimodal_dataset.py index e7a7784de6..23593ef973 100644 --- a/tests/unit/datasets/test_harmbench_multimodal_dataset.py +++ b/tests/unit/datasets/test_harmbench_multimodal_dataset.py @@ -144,3 +144,23 @@ def test_init_rejects_raw_string_matching_enum_value_for_categories(): """Test that raw strings matching enum values are rejected.""" with pytest.raises(ValueError, match="Expected SemanticCategory"): _HarmBenchMultimodalDataset(categories=["illegal"]) + + +@pytest.mark.asyncio +async def test_fetch_and_save_image_raises_when_memory_not_configured(): + """Test that _fetch_and_save_image_async raises RuntimeError when serializer memory is not configured.""" + from unittest.mock import MagicMock + + mock_serializer = MagicMock() + mock_memory = MagicMock() + mock_memory.results_path = None + mock_memory.results_storage_io = None + mock_serializer._memory = mock_memory + + with patch( + "pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset.data_serializer_factory", + return_value=mock_serializer, + ): + loader = _HarmBenchMultimodalDataset() + with pytest.raises(RuntimeError, match="Serializer memory is not properly configured"): + await loader._fetch_and_save_image_async(behavior_id="test_id", image_url="https://example.com/img.png") diff --git a/tests/unit/datasets/test_vlsu_multimodal_dataset.py b/tests/unit/datasets/test_vlsu_multimodal_dataset.py index b88c33dbb0..606a9c4c47 100644 --- a/tests/unit/datasets/test_vlsu_multimodal_dataset.py +++ b/tests/unit/datasets/test_vlsu_multimodal_dataset.py @@ -377,3 +377,23 @@ async def test_both_prompts_use_combined_category(self): # Both should use combined_category, not their individual categories for seed in dataset.seeds: assert seed.harm_categories == ["C1: Slurs, Hate Speech, Hate Symbols"] + + +@pytest.mark.asyncio +async def test_fetch_and_save_image_raises_when_memory_not_configured(): + """Test that _fetch_and_save_image_async raises RuntimeError when serializer memory is not configured.""" + from unittest.mock import MagicMock + + mock_serializer = MagicMock() + mock_memory = MagicMock() + mock_memory.results_path = None + mock_memory.results_storage_io = None + mock_serializer._memory = mock_memory + + with patch( + "pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset.data_serializer_factory", + return_value=mock_serializer, + ): + loader = _VLSUMultimodalDataset() + with pytest.raises(RuntimeError, match="Serializer memory is not properly configured"): + await loader._fetch_and_save_image_async(group_id="test_group", image_url="https://example.com/img.png") diff --git a/tests/unit/embedding/test_azure_text_embedding.py b/tests/unit/embedding/test_azure_text_embedding.py index 2716376dbf..8fb3400412 100644 --- a/tests/unit/embedding/test_azure_text_embedding.py +++ b/tests/unit/embedding/test_azure_text_embedding.py @@ -95,10 +95,9 @@ def mock_token_provider(): # Create instance with token provider embedding = OpenAITextEmbedding(api_key=mock_token_provider) - # Verify async client was created with the callable + # Verify async client was created with a callable (ensure_async_token_provider wraps sync→async) async_call_args = mock_async_openai.call_args assert callable(async_call_args.kwargs["api_key"]) - assert async_call_args.kwargs["api_key"]() == "mock-token" assert async_call_args.kwargs["base_url"] == "https://mock.azure.com/" assert embedding._async_client == mock_async_client diff --git a/tests/unit/executor/attack/core/test_attack_parameters.py b/tests/unit/executor/attack/core/test_attack_parameters.py index 47f853328e..ec667a7798 100644 --- a/tests/unit/executor/attack/core/test_attack_parameters.py +++ b/tests/unit/executor/attack/core/test_attack_parameters.py @@ -307,3 +307,15 @@ async def test_excluded_class_rejects_excluded_field_overrides(self) -> None: seed_group=seed_group, next_message=_make_message("user", "Should fail"), ) + + +@pytest.mark.asyncio +async def test_from_seed_group_async_raises_when_objective_is_none(): + """Test that from_seed_group_async raises ValueError when seed_group.objective is None.""" + seed_group = MagicMock(spec=SeedAttackGroup) + seed_group.validate = MagicMock() + seed_group.objective = None + seed_group.simulated_conversation = None + + with pytest.raises(ValueError, match="seed_group.objective is not initialized"): + await AttackParameters.from_seed_group_async(seed_group=seed_group) diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index 2ea2e5f40a..41e8b04c53 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -1790,3 +1790,17 @@ def test_add_adversarial_chat_conversation_id_ensures_uniqueness(self, basic_att ) in context.related_conversations ) + + +def test_tap_init_raises_when_objective_scorer_is_none(): + """Test that TAP __init__ raises ValueError when AttackScoringConfig has objective_scorer=None.""" + scoring_config = AttackScoringConfig(objective_scorer=None) + with pytest.raises(ValueError, match="objective_scorer is required"): + TreeOfAttacksWithPruningAttack( + objective_target=MagicMock(spec=PromptChatTarget), + attack_adversarial_config=MagicMock( + target=MagicMock(spec=PromptChatTarget), + system_prompt_path=None, + ), + attack_scoring_config=scoring_config, + ) \ No newline at end of file diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py index 23a348ff40..3b16e1968e 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py @@ -488,3 +488,22 @@ def test_prompt_node_multi_level_hierarchy(self) -> None: assert len(root.children) == 1 assert len(level1.children) == 1 assert len(level2.children) == 0 + + +def test_create_normalizer_requests_raises_when_seed_group_message_none(): + """Test that _create_normalizer_requests raises ValueError when seed_group.next_message is None.""" + from unittest.mock import PropertyMock + + from pyrit.executor.promptgen.fuzzer.fuzzer import FuzzerGenerator + + generator = FuzzerGenerator.__new__(FuzzerGenerator) + generator._request_converters = [] + generator._response_converters = [] + + with patch("pyrit.executor.promptgen.fuzzer.fuzzer.SeedGroup") as MockSeedGroup: + mock_instance = MagicMock() + type(mock_instance).next_message = PropertyMock(return_value=None) + MockSeedGroup.return_value = mock_instance + + with pytest.raises(ValueError, match="No message in seed group"): + generator._create_normalizer_requests(["test prompt"]) diff --git a/tests/unit/executor/promptgen/test_anecdoctor.py b/tests/unit/executor/promptgen/test_anecdoctor.py index 31d4667cad..3e40b00b87 100644 --- a/tests/unit/executor/promptgen/test_anecdoctor.py +++ b/tests/unit/executor/promptgen/test_anecdoctor.py @@ -538,3 +538,25 @@ def test_special_characters_in_data(self, mock_objective_target): result = generator._format_few_shot_examples(evaluation_data=evaluation_data) for data in evaluation_data: assert data in result + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_extract_knowledge_graph_raises_when_processing_model_is_none(): + """Test that _extract_knowledge_graph_async raises ValueError when processing model is None.""" + mock_target = MagicMock(spec=PromptChatTarget) + mock_target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", class_module="test_module" + ) + generator = AnecdoctorGenerator(objective_target=mock_target) + # Ensure processing model is explicitly None + assert generator._processing_model is None + + context = AnecdoctorContext( + evaluation_data=["sample data"], + language="english", + content_type="viral tweet", + ) + + with pytest.raises(ValueError, match="self._processing_model is not initialized"): + await generator._extract_knowledge_graph_async(context=context) diff --git a/tests/unit/executor/workflow/test_xpia.py b/tests/unit/executor/workflow/test_xpia.py index 2b90b213f6..0a5884acc3 100644 --- a/tests/unit/executor/workflow/test_xpia.py +++ b/tests/unit/executor/workflow/test_xpia.py @@ -615,3 +615,85 @@ def test_status_property_unknown(self) -> None: result = XPIAResult(processing_conversation_id="test-id", processing_response="test response", score=None) assert result.status == XPIAStatus.UNKNOWN + + +@pytest.mark.usefixtures("patch_central_database") +class TestXPIAGuards: + """Tests for type-narrowing guards in XPIA workflow.""" + + @pytest.mark.asyncio + async def test_execute_processing_raises_when_callback_is_none( + self, + ) -> None: + """Test that _execute_processing_async raises ValueError when processing_callback is None.""" + mock_target = MagicMock(spec=PromptTarget) + mock_target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", class_module="test_module" + ) + workflow = XPIAWorkflow(attack_setup_target=mock_target) + + attack_msg = Message( + message_pieces=[MessagePiece(role="user", original_value="attack content")] + ) + context = XPIAContext(attack_content=attack_msg, processing_callback=None) + + with pytest.raises(ValueError, match="processing_callback is not set"): + await workflow._execute_processing_async(context=context) + + @pytest.mark.asyncio + async def test_execute_processing_raises_when_memory_is_none( + self, + ) -> None: + """Test that _execute_processing_async raises RuntimeError when memory is None.""" + mock_target = MagicMock(spec=PromptTarget) + mock_target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", class_module="test_module" + ) + workflow = XPIAWorkflow(attack_setup_target=mock_target) + workflow._memory = None + + mock_callback = AsyncMock(return_value="response") + attack_msg = Message( + message_pieces=[MessagePiece(role="user", original_value="attack content")] + ) + context = XPIAContext(attack_content=attack_msg, processing_callback=mock_callback) + + with pytest.raises(RuntimeError, match="Memory not initialized"): + await workflow._execute_processing_async(context=context) + + @pytest.mark.asyncio + async def test_xpia_test_setup_raises_when_processing_prompt_is_none( + self, + ) -> None: + """Test that the process_async closure raises RuntimeError when processing_prompt is None.""" + from pyrit.executor.workflow.xpia import XPIATestWorkflow + + mock_target = MagicMock(spec=PromptTarget) + mock_target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", class_module="test_module" + ) + mock_processing_target = MagicMock(spec=PromptTarget) + mock_processing_target.get_identifier.return_value = ComponentIdentifier( + class_name="MockProcessingTarget", class_module="test_module" + ) + mock_scorer = MagicMock(spec=Scorer) + mock_scorer.get_identifier.return_value = ComponentIdentifier( + class_name="MockScorer", class_module="test_module" + ) + workflow = XPIATestWorkflow( + attack_setup_target=mock_target, + processing_target=mock_processing_target, + scorer=mock_scorer, + ) + + attack_msg = Message( + message_pieces=[MessagePiece(role="user", original_value="attack content")] + ) + context = XPIAContext(attack_content=attack_msg, processing_prompt=None) + + await workflow._setup_async(context=context) + + # The processing_callback should be set after _setup_async + assert context.processing_callback is not None + with pytest.raises(RuntimeError, match="context.processing_prompt is not initialized"): + await context.processing_callback() diff --git a/tests/unit/identifiers/test_component_identifier.py b/tests/unit/identifiers/test_component_identifier.py index 299f22933e..220de386a5 100644 --- a/tests/unit/identifiers/test_component_identifier.py +++ b/tests/unit/identifiers/test_component_identifier.py @@ -1333,3 +1333,12 @@ def test_mixed_children_with_and_without_eval_hash(self): children={"sub_scorers": [child_with, child_without]}, ) assert parent._collect_child_eval_hashes() == {"has_hash"} + + +def test_short_hash_raises_when_hash_none(): + obj = ComponentIdentifier.__new__(ComponentIdentifier) + object.__setattr__(obj, "hash", None) + object.__setattr__(obj, "class_name", "Test") + object.__setattr__(obj, "class_module", "test.module") + with pytest.raises(RuntimeError, match="hash should be set by __post_init__"): + obj.short_hash diff --git a/tests/unit/identifiers/test_evaluation_identifier.py b/tests/unit/identifiers/test_evaluation_identifier.py index 69eda9d489..e10d0f2f46 100644 --- a/tests/unit/identifiers/test_evaluation_identifier.py +++ b/tests/unit/identifiers/test_evaluation_identifier.py @@ -319,3 +319,12 @@ def test_eval_hash_preserved_through_double_roundtrip(self): # Second retrieve r2 = ComponentIdentifier.from_dict(d2) assert _StubEvaluationIdentifier(r2).eval_hash == correct_eval_hash + + +def test_compute_eval_hash_raises_when_hash_none_and_no_rules(): + identifier = ComponentIdentifier.__new__(ComponentIdentifier) + object.__setattr__(identifier, "hash", None) + object.__setattr__(identifier, "class_name", "Test") + object.__setattr__(identifier, "class_module", "test.module") + with pytest.raises(RuntimeError, match="hash should be set by __post_init__"): + compute_eval_hash(identifier, child_eval_rules={}) diff --git a/tests/unit/memory/memory_interface/test_interface_core.py b/tests/unit/memory/memory_interface/test_interface_core.py index dfa135377c..85255ccd65 100644 --- a/tests/unit/memory/memory_interface/test_interface_core.py +++ b/tests/unit/memory/memory_interface/test_interface_core.py @@ -1,8 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import pytest + from pyrit.memory import MemoryInterface def test_memory(sqlite_instance: MemoryInterface): assert sqlite_instance + + +def test_print_schema_raises_when_engine_none(): + # Test the MemoryInterface.print_schema guard; use AzureSQLMemory which inherits it without override + from pyrit.memory import AzureSQLMemory + + obj = AzureSQLMemory.__new__(AzureSQLMemory) + obj.engine = None + with pytest.raises(RuntimeError, match="Engine is not initialized"): + obj.print_schema() diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index acf4420604..4d9e056c5b 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -6,6 +6,7 @@ from collections.abc import Generator, MutableSequence, Sequence from datetime import timezone from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch import pytest @@ -405,3 +406,48 @@ def test_update_prompt_metadata_by_conversation_id(memory_interface: AzureSQLMem with memory_interface.get_session() as session: # type: ignore[arg-type] updated_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() assert updated_entry.prompt_metadata == {"updated": "updated"} + + +def test_refresh_token_if_needed_raises_when_expiry_none(): + obj = AzureSQLMemory.__new__(AzureSQLMemory) + obj._auth_token_expiry = None + with pytest.raises(RuntimeError, match="Auth token expiry not initialized"): + obj._refresh_token_if_needed() + + +def test_provide_token_raises_when_auth_token_none(): + obj = AzureSQLMemory.__new__(AzureSQLMemory) + obj._auth_token = None + obj._auth_token_expiry = 9999999999.0 + obj.engine = MagicMock() + + captured_fn = None + + def fake_listens_for(*args, **kwargs): + def decorator(fn): + nonlocal captured_fn + captured_fn = fn + return fn + + return decorator + + with patch("pyrit.memory.azure_sql_memory.event.listens_for", side_effect=fake_listens_for): + obj._enable_azure_authorization() + + assert captured_fn is not None + with pytest.raises(RuntimeError, match="Azure auth token is not initialized"): + captured_fn(None, None, ["some_connection_string"], {}) + + +def test_create_tables_if_not_exist_raises_when_engine_none(): + obj = AzureSQLMemory.__new__(AzureSQLMemory) + obj.engine = None + with pytest.raises(RuntimeError, match="Engine is not initialized"): + obj._create_tables_if_not_exist() + + +def test_reset_database_raises_when_engine_none(): + obj = AzureSQLMemory.__new__(AzureSQLMemory) + obj.engine = None + with pytest.raises(RuntimeError, match="Engine is not initialized"): + obj.reset_database() diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index ba07356578..71ed421e8c 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -699,3 +699,21 @@ def test_create_engine_uses_static_pool_for_in_memory(sqlite_instance): from sqlalchemy.pool import StaticPool assert isinstance(sqlite_instance.engine.pool, StaticPool) + + +def test_create_tables_raises_when_engine_none(): + from pyrit.memory import SQLiteMemory + + obj = SQLiteMemory.__new__(SQLiteMemory) + obj.engine = None + with pytest.raises(RuntimeError, match="Engine is not initialized"): + obj._create_tables_if_not_exist() + + +def test_reset_database_raises_when_engine_none(): + from pyrit.memory import SQLiteMemory + + obj = SQLiteMemory.__new__(SQLiteMemory) + obj.engine = None + with pytest.raises(RuntimeError, match="Engine is not initialized"): + obj.reset_database() diff --git a/tests/unit/models/test_data_type_serializer.py b/tests/unit/models/test_data_type_serializer.py index 88c39e5562..2984166e3c 100644 --- a/tests/unit/models/test_data_type_serializer.py +++ b/tests/unit/models/test_data_type_serializer.py @@ -6,7 +6,7 @@ import re import tempfile from typing import get_args -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch import pytest from PIL import Image @@ -368,3 +368,71 @@ async def test_binary_path_subdirectory(sqlite_instance): serializer = data_serializer_factory(category="prompt-memory-entries", data_type="binary_path") await serializer.save_data(b"test data") assert "/binaries/" in serializer.value or "\\binaries\\" in serializer.value + + +def test_get_storage_io_raises_when_results_storage_io_none(): + serializer = data_serializer_factory(category="prompt-memory-entries", data_type="image_path") + serializer.value = "https://account.blob.core.windows.net/container/path/image.png" + mock_memory = MagicMock() + mock_memory.results_storage_io = None + with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory): + with pytest.raises(RuntimeError, match="results_storage_io is not configured"): + serializer._get_storage_io() + + +@pytest.mark.asyncio +async def test_save_data_raises_when_results_storage_io_none(): + serializer = data_serializer_factory(category="prompt-memory-entries", data_type="image_path") + mock_memory = MagicMock() + mock_memory.results_storage_io = None + with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory): + with patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value="local/path/img.png"): + with pytest.raises(RuntimeError, match="Storage IO not initialized"): + await serializer.save_data(b"\x89PNG") + + +@pytest.mark.asyncio +async def test_save_b64_image_raises_when_results_storage_io_none(): + serializer = data_serializer_factory(category="prompt-memory-entries", data_type="image_path") + mock_memory = MagicMock() + mock_memory.results_storage_io = None + with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory): + with patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value="local/path/img.png"): + import base64 + + b64_data = base64.b64encode(b"\x89PNG").decode() + with pytest.raises(RuntimeError, match="Storage IO not initialized"): + await serializer.save_b64_image(b64_data) + + +@pytest.mark.asyncio +async def test_save_formatted_audio_raises_when_results_storage_io_none(): + from pyrit.models import data_serializer_factory as factory + + serializer = factory(category="prompt-memory-entries", data_type="audio_path") + mock_memory = MagicMock() + mock_memory.results_storage_io = None + azure_url = "https://account.blob.core.windows.net/container/audio/test.wav" + with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory): + with patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value=azure_url): + with patch("wave.open"): + with patch("aiofiles.open", new_callable=MagicMock) as mock_aio: + mock_file = MagicMock() + mock_file.__aenter__ = AsyncMock(return_value=mock_file) + mock_file.__aexit__ = AsyncMock(return_value=False) + mock_file.read = AsyncMock(return_value=b"audio_bytes") + mock_aio.return_value = mock_file + with pytest.raises(RuntimeError, match="results_storage_io is not initialized"): + await serializer.save_formatted_audio(data=b"\x00\x01\x02") + + +@pytest.mark.asyncio +async def test_get_data_filename_raises_when_results_storage_io_none(): + serializer = data_serializer_factory(category="prompt-memory-entries", data_type="image_path") + serializer._file_path = None + mock_memory = MagicMock() + mock_memory.results_storage_io = None + mock_memory.results_path = "/local/results" + with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory): + with pytest.raises(RuntimeError, match="results_storage_io is not initialized"): + await serializer.get_data_filename() diff --git a/tests/unit/models/test_seed_attack_group.py b/tests/unit/models/test_seed_attack_group.py index 4321a7fbb9..ecfb2959e3 100644 --- a/tests/unit/models/test_seed_attack_group.py +++ b/tests/unit/models/test_seed_attack_group.py @@ -65,3 +65,14 @@ def test_seed_attack_group_with_multiple_prompts(): p2 = _make_prompt(value="p2", sequence=1) group = SeedAttackGroup(seeds=[objective, p1, p2]) assert len(group.prompts) == 2 + + +def test_seed_attack_group_objective_raises_when_get_objective_returns_none(): + from unittest.mock import patch + + prompt = _make_prompt() + objective = _make_objective() + group = SeedAttackGroup(seeds=[objective, prompt]) + with patch.object(type(group), "_get_objective", return_value=None): + with pytest.raises(ValueError, match="SeedAttackGroup should always have an objective"): + group.objective diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/models/test_storage_io.py index 0159d65b91..674324564c 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/models/test_storage_io.py @@ -301,3 +301,11 @@ def test_resolve_blob_name_with_path_object(azure_blob_storage_io): result = azure_blob_storage_io._resolve_blob_name(PurePosixPath("dir1/dir2/file.txt")) assert result == "dir1/dir2/file.txt" + + +@pytest.mark.asyncio +async def test_upload_blob_raises_when_client_async_none(): + obj = AzureBlobStorageIO.__new__(AzureBlobStorageIO) + obj._client_async = None + with pytest.raises(RuntimeError, match="Azure container client not initialized"): + await obj._upload_blob_async(file_name="test.txt", data=b"data", content_type="text/plain") diff --git a/tests/unit/prompt_converter/test_add_image_video_converter.py b/tests/unit/prompt_converter/test_add_image_video_converter.py index ec297fcd3f..03d367a3cc 100644 --- a/tests/unit/prompt_converter/test_add_image_video_converter.py +++ b/tests/unit/prompt_converter/test_add_image_video_converter.py @@ -106,3 +106,38 @@ async def test_add_image_video_converter_convert_async(video_converter_sample_vi os.remove(video_converter_sample_video) os.remove(video_converter_sample_image) os.remove("output_video.mp4") + + +@pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") +@pytest.mark.asyncio +async def test_add_image_to_video_raises_when_decode_returns_none(video_converter_sample_video): + """Guard at line 146: cv2.imdecode returns None raises ValueError.""" + from unittest.mock import AsyncMock, patch + + converter = AddImageVideoConverter(video_path=video_converter_sample_video, output_path="output_video.mp4") + + # Mock the data serializer to return invalid image bytes (not a valid image) + mock_image_serializer = AsyncMock() + mock_image_serializer.read_data = AsyncMock(return_value=b"not_valid_image_data") + mock_image_serializer._is_azure_storage_url = lambda x: False + + mock_video_serializer = AsyncMock() + with open(video_converter_sample_video, "rb") as f: + video_bytes = f.read() + mock_video_serializer.read_data = AsyncMock(return_value=video_bytes) + mock_video_serializer._is_azure_storage_url = lambda x: False + + def factory_side_effect(*, category, data_type, value): + if data_type == "image_path": + return mock_image_serializer + return mock_video_serializer + + with patch( + "pyrit.prompt_converter.add_image_to_video_converter.data_serializer_factory", + side_effect=factory_side_effect, + ): + with pytest.raises(ValueError, match="Failed to decode overlay image"): + await converter._add_image_to_video( + image_path="fake_image.png", output_path="output_video_test.mp4" + ) + os.remove(video_converter_sample_video) diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 6386a1024a..24f0b3cf26 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -117,14 +117,17 @@ async def test_send_prompt_async_multiple_converters(mock_memory_instance, seed_ @pytest.mark.asyncio -async def test_send_prompt_async_no_response_adds_memory(mock_memory_instance, seed_group): +async def test_send_prompt_async_no_response_raises_empty_response(mock_memory_instance, seed_group): prompt_target = AsyncMock() prompt_target.send_prompt_async = AsyncMock(return_value=None) normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") - await normalizer.send_prompt_async(message=message, target=prompt_target) + with pytest.raises(EmptyResponseException): + await normalizer.send_prompt_async(message=message, target=prompt_target) + + # Request should still be added to memory before the exception assert mock_memory_instance.add_message_to_memory.call_count == 1 request = mock_memory_instance.add_message_to_memory.call_args[1]["request"] @@ -556,3 +559,11 @@ async def test_convert_values_context_includes_converter_identifier(self, mock_m assert captured is not None assert captured.component_identifier is not None assert "ContextCapturingConverter" in str(captured.component_identifier) + + +def test_memory_property_raises_when_memory_none(): + """Guard at line 45: _memory is None raises RuntimeError.""" + normalizer = PromptNormalizer.__new__(PromptNormalizer) + normalizer._memory = None + with pytest.raises(RuntimeError, match="Memory is not initialized"): + normalizer.memory diff --git a/tests/unit/prompt_target/target/test_none_guard_openai_target.py b/tests/unit/prompt_target/target/test_none_guard_openai_target.py new file mode 100644 index 0000000000..acfad74ed6 --- /dev/null +++ b/tests/unit/prompt_target/target/test_none_guard_openai_target.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from openai import BadRequestError, ContentFilterFinishReasonError + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import OpenAIChatTarget + + +def test_client_property_raises_when_async_client_none(patch_central_database): + target = OpenAIChatTarget(endpoint="https://test.openai.com", api_key="test", model_name="gpt-4") + target._async_client = None + with pytest.raises(RuntimeError, match="AsyncOpenAI client is not initialized"): + target._client + + +@pytest.mark.asyncio +async def test_handle_openai_request_raises_when_no_message_pieces(patch_central_database): + """The try-block guard (line 442) raises when request has no message_pieces.""" + target = OpenAIChatTarget(endpoint="https://test.openai.com", api_key="test", model_name="gpt-4") + empty_request = MagicMock(spec=Message) + empty_request.message_pieces = [] + + api_call = AsyncMock(return_value=MagicMock()) + + with pytest.raises(ValueError, match="No message pieces in request"): + await target._handle_openai_request(api_call=api_call, request=empty_request) + + +@pytest.mark.asyncio +async def test_handle_openai_request_content_filter_error_raises_when_no_message_pieces(patch_central_database): + """The ContentFilterFinishReasonError handler (line 470) raises when request has no pieces.""" + target = OpenAIChatTarget(endpoint="https://test.openai.com", api_key="test", model_name="gpt-4") + empty_request = MagicMock(spec=Message) + empty_request.message_pieces = [] + + api_call = AsyncMock( + side_effect=ContentFilterFinishReasonError(), + ) + + with pytest.raises(ValueError, match="No message pieces in request"): + await target._handle_openai_request(api_call=api_call, request=empty_request) + + +@pytest.mark.asyncio +async def test_handle_openai_request_bad_request_error_raises_when_no_message_pieces(patch_central_database): + """The BadRequestError handler (line 490) raises when request has no pieces.""" + target = OpenAIChatTarget(endpoint="https://test.openai.com", api_key="test", model_name="gpt-4") + empty_request = MagicMock(spec=Message) + empty_request.message_pieces = [] + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = {"error": {"message": "bad request", "code": "invalid_request"}} + mock_response.headers = {} + + api_call = AsyncMock( + side_effect=BadRequestError( + message="bad request", + response=mock_response, + body={"error": {"message": "bad request", "code": "invalid_request"}}, + ), + ) + + with pytest.raises(ValueError, match="No message pieces in request"): + await target._handle_openai_request(api_call=api_call, request=empty_request) diff --git a/tests/unit/prompt_target/target/test_prompt_shield_target.py b/tests/unit/prompt_target/target/test_prompt_shield_target.py index 37cfdfd7ae..4443080f62 100644 --- a/tests/unit/prompt_target/target/test_prompt_shield_target.py +++ b/tests/unit/prompt_target/target/test_prompt_shield_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from collections.abc import MutableSequence -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from unit.mocks import get_audio_message_piece, get_sample_conversations @@ -110,3 +110,20 @@ def test_token_provider_authentication(): assert target is not None assert target._api_key == token_provider assert callable(target._api_key) + + +def test_init_raises_when_endpoint_none(): + """Guard at line 98: endpoint_value is None raises ValueError.""" + with patch("pyrit.prompt_target.prompt_shield_target.default_values") as mock_dv: + mock_dv.get_required_value = MagicMock(return_value=None) + with pytest.raises(ValueError, match="Endpoint value is required"): + PromptShieldTarget(endpoint=None, api_key="test_key") + + +def test_init_raises_when_api_key_none(sqlite_instance): + """Guard at line 113: _api_key_value is None raises ValueError.""" + with patch("pyrit.prompt_target.prompt_shield_target.default_values") as mock_dv: + # First call for endpoint returns valid, second call for api_key returns None + mock_dv.get_required_value = MagicMock(side_effect=["https://test.endpoint.com", None]) + with pytest.raises(ValueError, match="API key is required"): + PromptShieldTarget(endpoint=None, api_key=None) diff --git a/tests/unit/prompt_target/target/test_prompt_target_azure_blob_storage.py b/tests/unit/prompt_target/target/test_prompt_target_azure_blob_storage.py index effc4ba854..736c80cf9f 100644 --- a/tests/unit/prompt_target/target/test_prompt_target_azure_blob_storage.py +++ b/tests/unit/prompt_target/target/test_prompt_target_azure_blob_storage.py @@ -146,3 +146,18 @@ async def test_send_prompt_async( assert azure_blob_storage_target._container_url in blob_url assert blob_url.endswith(".txt") mock_upload_blob.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_upload_blob_async_raises_when_client_async_none(azure_blob_storage_target: AzureBlobStorageTarget): + """Guard at line 169: _client_async is None after _create_container_client_async still leaves it None.""" + azure_blob_storage_target._client_async = None + with patch.object( + AzureBlobStorageTarget, "_create_container_client_async", new_callable=AsyncMock + ): + # After the mock _create_container_client_async, _client_async remains None + with patch.object(AzureBlobStorageTarget, "_parse_url", return_value=("container", "")): + with pytest.raises(RuntimeError, match="Blob storage client not initialized"): + await azure_blob_storage_target._upload_blob_async( + file_name="test.txt", data=b"hello", content_type="text/plain" + ) diff --git a/tests/unit/prompt_target/target/test_video_target.py b/tests/unit/prompt_target/target/test_video_target.py index faba7830a0..25a5ecbad3 100644 --- a/tests/unit/prompt_target/target/test_video_target.py +++ b/tests/unit/prompt_target/target/test_video_target.py @@ -1171,3 +1171,23 @@ def test_remix_raises_when_video_path_missing_video_id(self, video_target: OpenA with pytest.raises(ValueError, match="video_path piece is missing.*video_id"): OpenAIVideoTarget._validate_video_remix_pieces(message=message) + + +@pytest.mark.asyncio +async def test_send_prompt_async_raises_when_no_text_piece(patch_central_database): + """Guard at line 210: text_piece is None raises ValueError.""" + target = OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + msg = MessagePiece( + role="user", + original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + ) + message = Message([msg]) + with patch.object(target, "_validate_request"): + with pytest.raises(ValueError, match="No text piece found in message"): + await target.send_prompt_async(message=message) diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 7f02982015..dd389b7b54 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -841,3 +841,17 @@ def test_returns_fallback_when_registry_empty(self, mock_registry_cls, mock_oai_ result = Scenario._get_default_objective_scorer(MagicMock()) assert isinstance(result, TrueFalseInverterScorer) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_execute_scenario_raises_when_scenario_result_id_is_none(): + """Test that _execute_scenario_async raises ValueError when _scenario_result_id is None.""" + scenario = ConcreteScenario.__new__(ConcreteScenario) + scenario._scenario_result_id = None + scenario._name = "test_scenario" + scenario._atomic_attacks = [] + scenario._memory = MagicMock() + + with pytest.raises(ValueError, match="self._scenario_result_id is not initialized"): + await scenario._execute_scenario_async() diff --git a/tests/unit/score/test_azure_content_filter.py b/tests/unit/score/test_azure_content_filter.py index 9cd391398d..d997c00d3c 100644 --- a/tests/unit/score/test_azure_content_filter.py +++ b/tests/unit/score/test_azure_content_filter.py @@ -295,3 +295,13 @@ async def test_evaluate_async_sets_file_mapping_for_single_category(patch_centra # Parent evaluate_async should be called mock_eval.assert_called_once() + + +def test_init_raises_runtime_error_when_api_key_not_string(): + """Test that __init__ raises RuntimeError when resolved api_key is neither callable nor string.""" + with patch( + "pyrit.score.float_scale.azure_content_filter_scorer.ensure_async_token_provider", + return_value=12345, + ): + with pytest.raises(RuntimeError, match="Expected string API key"): + AzureContentFilterScorer(api_key="foo", endpoint="https://example.com") diff --git a/tests/unit/score/test_general_float_scale_scorer.py b/tests/unit/score/test_general_float_scale_scorer.py index 7ee85404d5..9d9d59a2d6 100644 --- a/tests/unit/score/test_general_float_scale_scorer.py +++ b/tests/unit/score/test_general_float_scale_scorer.py @@ -158,3 +158,32 @@ def test_general_float_scorer_init_invalid_min_max(): min_value=10, max_value=5, ) + + +def test_get_scorer_metrics_returns_none_when_eval_hash_is_none(patch_central_database): + """Test that get_scorer_metrics returns None when eval_hash is None.""" + from unittest.mock import patch as _patch + + from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer + from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles + + chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") + + scorer = SelfAskGeneralFloatScaleScorer( + chat_target=chat_target, + system_prompt_format_string="Prompt.", + category="test_category", + ) + # Set evaluation_file_mapping with harm_category so the early return before eval_hash is bypassed + scorer.evaluation_file_mapping = ScorerEvalDatasetFiles( + human_labeled_datasets_files=["harm/*.csv"], + result_file="harm/test_metrics.jsonl", + harm_category="hate_speech", + ) + # Mock get_identifier to return an identifier with eval_hash=None + mock_identifier = MagicMock() + mock_identifier.eval_hash = None + with _patch.object(scorer, "get_identifier", return_value=mock_identifier): + result = scorer.get_scorer_metrics() + assert result is None diff --git a/tests/unit/score/test_general_true_false_scorer.py b/tests/unit/score/test_general_true_false_scorer.py index 49e4b98397..e7130167e2 100644 --- a/tests/unit/score/test_general_true_false_scorer.py +++ b/tests/unit/score/test_general_true_false_scorer.py @@ -114,3 +114,22 @@ async def test_general_scorer_score_async_handles_custom_keys(patch_central_data assert score[0].score_value == "false" assert "This is the rationale." in score[0].score_rationale assert "This is the description." in score[0].score_value_description + + +def test_true_false_get_scorer_metrics_returns_none_when_eval_hash_is_none(patch_central_database): + """Test that TrueFalseScorer.get_scorer_metrics returns None when eval_hash is None.""" + from unittest.mock import patch as _patch + + from pyrit.score.true_false.self_ask_true_false_scorer import ( + SelfAskTrueFalseScorer, + ) + + chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") + + scorer = SelfAskTrueFalseScorer(chat_target=chat_target) + mock_identifier = MagicMock() + mock_identifier.eval_hash = None + with _patch.object(scorer, "get_identifier", return_value=mock_identifier): + result = scorer.get_scorer_metrics() + assert result is None \ No newline at end of file diff --git a/tests/unit/score/test_scorer_evaluator.py b/tests/unit/score/test_scorer_evaluator.py index cf1d379d66..c8bcb65309 100644 --- a/tests/unit/score/test_scorer_evaluator.py +++ b/tests/unit/score/test_scorer_evaluator.py @@ -818,3 +818,41 @@ async def test_run_evaluation_async_raises_when_harm_csv_missing_harm_definition num_scorer_trials=1, update_registry_behavior=RegistryUpdateBehavior.NEVER_UPDATE, ) + + +def test_should_skip_evaluation_returns_false_when_eval_hash_is_none(tmp_path): + """Test that _should_skip_evaluation returns (False, None) when scorer eval_hash is None.""" + scorer = MagicMock(spec=TrueFalseScorer) + mock_identifier = MagicMock() + mock_identifier.eval_hash = None + scorer.get_identifier = MagicMock(return_value=mock_identifier) + + evaluator = ObjectiveScorerEvaluator(scorer=scorer) + result_file = tmp_path / "test_results.jsonl" + + should_skip, result = evaluator._should_skip_evaluation( + dataset_version="1.0", + num_scorer_trials=3, + harm_category=None, + result_file_path=result_file, + ) + + assert should_skip is False + assert result is None + + +@patch("pyrit.score.scorer_evaluation.scorer_evaluator.replace_evaluation_results") +def test_write_metrics_to_registry_returns_early_when_eval_hash_is_none(mock_replace, tmp_path): + """Test that _write_metrics_to_registry returns early when scorer eval_hash is None.""" + scorer = MagicMock(spec=FloatScaleScorer) + mock_identifier = MagicMock() + mock_identifier.eval_hash = None + scorer.get_identifier = MagicMock(return_value=mock_identifier) + + evaluator = HarmScorerEvaluator(scorer=scorer) + result_file = tmp_path / "test_results.jsonl" + + metrics = MagicMock() + evaluator._write_metrics_to_registry(metrics=metrics, result_file_path=result_file) + + mock_replace.assert_not_called() diff --git a/tests/unit/score/test_self_ask_true_false.py b/tests/unit/score/test_self_ask_true_false.py index 17a10048c6..7d1921eb06 100644 --- a/tests/unit/score/test_self_ask_true_false.py +++ b/tests/unit/score/test_self_ask_true_false.py @@ -251,3 +251,16 @@ def test_self_ask_true_false_with_path_and_question(patch_central_database): true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value, true_false_question=custom_question, ) + + +def test_self_ask_true_false_raises_when_yaml_loads_none(patch_central_database): + """Test that ValueError is raised when YAML file loads as None.""" + chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") + + with patch("pyrit.score.true_false.self_ask_true_false_scorer.yaml.safe_load", return_value=None): + with pytest.raises(ValueError, match="Failed to load true_false_question YAML"): + SelfAskTrueFalseScorer( + chat_target=chat_target, + true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value, + ) diff --git a/tests/unit/score/test_true_false_composite_scorer.py b/tests/unit/score/test_true_false_composite_scorer.py index 3824a95574..c434902412 100644 --- a/tests/unit/score/test_true_false_composite_scorer.py +++ b/tests/unit/score/test_true_false_composite_scorer.py @@ -184,3 +184,17 @@ def test_composite_scorer_empty_scorers_list(): """Test that TrueFalseCompositeScorer raises an exception when given an empty list of scorers.""" with pytest.raises(ValueError, match="At least one scorer must be provided"): TrueFalseCompositeScorer(aggregator=TrueFalseScoreAggregator.AND, scorers=[]) + + +@pytest.mark.asyncio +async def test_composite_scorer_raises_when_message_piece_id_is_none(true_scorer, patch_central_database): + """Test that _score_async raises ValueError when message piece has no ID.""" + scorer = TrueFalseCompositeScorer(aggregator=TrueFalseScoreAggregator.AND, scorers=[true_scorer]) + + # Create a message with a piece whose id is None + piece = MessagePiece(role="user", original_value="test content") + piece.id = None + message = piece.to_message() + + with pytest.raises(RuntimeError, match="Message piece must have an ID"): + await scorer.score_async(message) diff --git a/tests/unit/setup/test_airt_initializer.py b/tests/unit/setup/test_airt_initializer.py index 61f74cbe57..95d96c90a4 100644 --- a/tests/unit/setup/test_airt_initializer.py +++ b/tests/unit/setup/test_airt_initializer.py @@ -247,3 +247,45 @@ async def test_get_info_includes_description(self): assert "description" in info assert isinstance(info["description"], str) assert len(info["description"]) > 0 + + +@pytest.mark.asyncio +async def test_initialize_async_raises_when_converter_endpoint_is_none(): + """Test that initialize_async raises ValueError when converter_endpoint env var is None.""" + init = AIRTInitializer() + with ( + patch.object(init, "_validate_operation_fields"), + patch.dict( + "os.environ", + { + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2": "https://test.openai.azure.com", + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2": "gpt-4", + }, + clear=False, + ), + patch.dict("os.environ", {"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": ""}, clear=False), + ): + # Remove the key to force None + os.environ.pop("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", None) + with pytest.raises(ValueError, match="converter_endpoint is not initialized"): + await init.initialize_async() + + +@pytest.mark.asyncio +async def test_initialize_async_raises_when_scorer_endpoint_is_none(): + """Test that initialize_async raises ValueError when scorer_endpoint env var is None.""" + init = AIRTInitializer() + with ( + patch.object(init, "_validate_operation_fields"), + patch.dict( + "os.environ", + { + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": "https://test.openai.azure.com", + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL": "gpt-4", + }, + clear=False, + ), + ): + os.environ.pop("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", None) + with pytest.raises(ValueError, match="scorer_endpoint is not initialized"): + await init.initialize_async()