From fa5bc5ca1f7c6f1c39ba8f13e503e47ee5469eab Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Thu, 7 May 2026 16:37:47 +0200 Subject: [PATCH] feat: add protocol for override hooks OverrideHookProtocol documents the interface for per-package override hooks and provides runtime validation of hook signatures via check_signature(). The lint command now checks override hook signatures. Co-Authored-By: Claude Signed-off-by: Christian Heimes --- src/fromager/commands/lint.py | 16 ++ src/fromager/overrides.py | 272 ++++++++++++++++++++++++++++++++++ tests/test_overrides.py | 49 ++++++ 3 files changed, 337 insertions(+) diff --git a/src/fromager/commands/lint.py b/src/fromager/commands/lint.py index 5ac8172b..eb061574 100644 --- a/src/fromager/commands/lint.py +++ b/src/fromager/commands/lint.py @@ -61,5 +61,21 @@ def lint( errors += 1 logger.error(f"ERROR: plugin name {name} should be {expected_name}") + logger.info("Checking override hook signatures...") + hook_names = overrides.OverrideHookProtocol.list_hooks() + for ext in exts: + mod = ext.plugin + for hook_name in hook_names: + func = getattr(mod, hook_name, None) + if func is None: + continue + try: + overrides.OverrideHookProtocol.check_signature( + func, hook_name=hook_name + ) + except TypeError as e: + errors += 1 + logger.error(f"ERROR: override {ext.name}.{hook_name}: {e}") + if errors: raise SystemExit(f"Found {errors} errors") diff --git a/src/fromager/overrides.py b/src/fromager/overrides.py index 2a95b1a3..3971c13b 100644 --- a/src/fromager/overrides.py +++ b/src/fromager/overrides.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import importlib import inspect import logging import pathlib @@ -6,8 +9,14 @@ from packaging.requirements import Requirement from packaging.utils import canonicalize_name +from packaging.version import Version from stevedore import extension +if typing.TYPE_CHECKING: + from . import build_environment, context + from .requirements_file import RequirementType + from .resolver import BaseProvider + # An interface for reretrieving per-package information which influences # the build process for a particular package - i.e. for a given package # and build target, what patches should we apply, what environment variables @@ -134,3 +143,266 @@ def find_override_method(distname: str, method: str) -> typing.Callable | None: return None logger.info("%s: found %s override", distname, method) return typing.cast(typing.Callable, getattr(mod, method)) + + +_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + +def _default_hook(module: str, func: str) -> typing.Callable[[_F], _F]: + """Decorator that annotates a Protocol method with its default implementation. + + Stores a ``fromager_default`` attribute as a ``(module, func)`` tuple + on the decorated function so the mapping from hook name to default can + be discovered at runtime. + """ + + def decorator(fn: _F) -> _F: + fn.fromager_default = (module, func) # type: ignore[attr-defined] + return fn + + return decorator + + +class OverrideHookProtocol(typing.Protocol): + """Protocol defining the interface for per-package override hooks. + + Override modules may implement any subset of these methods to customize + the build process for a specific package. See the default implementations + for each hook's behavior when no override is provided. + """ + + @classmethod + def list_hooks(cls) -> list[str]: + """Return a list of hook names defined on this Protocol.""" + return [ + name for name, obj in vars(cls).items() if hasattr(obj, "fromager_default") + ] + + @classmethod + def get_default(cls, hook_name: str) -> typing.Callable[..., typing.Any]: + """Return the default function object for a hook name.""" + obj = getattr(cls, hook_name, None) + if obj is None or not hasattr(obj, "fromager_default"): + raise KeyError(hook_name) + module_name, func_name = obj.fromager_default + module = importlib.import_module(module_name) + return typing.cast(typing.Callable[..., typing.Any], getattr(module, func_name)) + + @classmethod + def check_signature( + cls, + func: typing.Callable[..., typing.Any], + *, + hook_name: str | None = None, + ) -> None: + """Check that a function's argument names match the protocol method. + + Only argument names are compared; the check ignores whether arguments + are positional or keyword-only because all hooks are called with + keyword arguments. + """ + if hook_name is None: + hook_name = func.__name__ + proto_method = getattr(cls, hook_name, None) + if proto_method is None or not hasattr(proto_method, "fromager_default"): + raise KeyError(hook_name) + proto_spec = inspect.getfullargspec(proto_method) + # Skip 'self' (first parameter of a protocol method) + expected_args = set(proto_spec.args[1:] + proto_spec.kwonlyargs) + func_spec = inspect.getfullargspec(func) + func_args = set(func_spec.args + func_spec.kwonlyargs) + if expected_args != func_args: + raise TypeError( + f"{hook_name}: argument names mismatch: " + f"expected {sorted(expected_args)}, got {sorted(func_args)}" + ) + + @_default_hook("fromager.wheels", "default_add_extra_metadata_to_wheels") + def add_extra_metadata_to_wheels( + self, + ctx: context.WorkContext, + req: Requirement, + version: Version, + extra_environ: dict[str, str], + sdist_root_dir: pathlib.Path, + dist_info_dir: pathlib.Path, + ) -> dict[str, typing.Any]: + """Add extra metadata files to built wheels. + + Default: :func:`fromager.wheels.default_add_extra_metadata_to_wheels` + """ + + @_default_hook("fromager.sources", "default_build_sdist") + def build_sdist( + self, + ctx: context.WorkContext, + extra_environ: dict, + req: Requirement, + version: Version, + sdist_root_dir: pathlib.Path, + build_env: build_environment.BuildEnvironment, + build_dir: pathlib.Path, + ) -> pathlib.Path: + """Build an sdist from the prepared source tree. + + Default: :func:`fromager.sources.default_build_sdist` + """ + + @_default_hook("fromager.wheels", "default_build_wheel") + def build_wheel( + self, + ctx: context.WorkContext, + build_env: build_environment.BuildEnvironment, + extra_environ: dict[str, str], + req: Requirement, + sdist_root_dir: pathlib.Path, + version: Version, + build_dir: pathlib.Path, + ) -> pathlib.Path: + """Build a wheel from the prepared source tree. + + Default: :func:`fromager.wheels.default_build_wheel` + """ + + @_default_hook("fromager.sources", "default_download_source") + def download_source( + self, + ctx: context.WorkContext, + req: Requirement, + version: Version, + download_url: str, + sdists_downloads_dir: pathlib.Path, + ) -> pathlib.Path: + """Download the source archive for a requirement. + + Default: :func:`fromager.sources.default_download_source` + """ + + @_default_hook("fromager.finders", "default_expected_source_archive_name") + def expected_source_archive_name( + self, + ctx: context.WorkContext, + req: Requirement, + dist_version: str, + ) -> str | None: + """Return the expected filename for a source archive. + + Default: :func:`fromager.finders.default_expected_source_archive_name` + """ + + @_default_hook("fromager.finders", "default_expected_source_directory_name") + def expected_source_directory_name( + self, + req: Requirement, + dist_version: str, + ) -> str: + """Return the expected directory name after unpacking a source archive. + + Default: :func:`fromager.finders.default_expected_source_directory_name` + """ + + @_default_hook("fromager.dependencies", "default_get_build_backend_dependencies") + def get_build_backend_dependencies( + self, + ctx: context.WorkContext, + req: Requirement, + sdist_root_dir: pathlib.Path, + build_dir: pathlib.Path, + extra_environ: dict[str, str], + build_env: build_environment.BuildEnvironment, + ) -> typing.Iterable[str]: + """Get build backend dependencies (PEP 517 get_requires_for_build_wheel). + + Default: :func:`fromager.dependencies.default_get_build_backend_dependencies` + """ + + @_default_hook("fromager.dependencies", "default_get_build_sdist_dependencies") + def get_build_sdist_dependencies( + self, + ctx: context.WorkContext, + req: Requirement, + sdist_root_dir: pathlib.Path, + build_dir: pathlib.Path, + extra_environ: dict[str, str], + build_env: build_environment.BuildEnvironment, + ) -> typing.Iterable[str]: + """Get build sdist dependencies. + + Default: :func:`fromager.dependencies.default_get_build_sdist_dependencies` + """ + + @_default_hook("fromager.dependencies", "default_get_build_system_dependencies") + def get_build_system_dependencies( + self, + ctx: context.WorkContext, + req: Requirement, + sdist_root_dir: pathlib.Path, + build_dir: pathlib.Path, + ) -> typing.Iterable[str]: + """Get build system dependencies from pyproject.toml [build-system] requires. + + Default: :func:`fromager.dependencies.default_get_build_system_dependencies` + """ + + @_default_hook("fromager.dependencies", "default_get_install_dependencies_of_sdist") + def get_install_dependencies_of_sdist( + self, + *, + ctx: context.WorkContext, + req: Requirement, + version: Version, + sdist_root_dir: pathlib.Path, + build_env: build_environment.BuildEnvironment, + extra_environ: dict[str, str], + build_dir: pathlib.Path, + config_settings: dict[str, str], + ) -> set[Requirement]: + """Get install dependencies (Requires-Dist) from source distribution. + + Default: :func:`fromager.dependencies.default_get_install_dependencies_of_sdist` + """ + + @_default_hook("fromager.resolver", "default_resolver_provider") + def get_resolver_provider( + self, + ctx: context.WorkContext, + req: Requirement, + sdist_server_url: str, + include_sdists: bool, + include_wheels: bool, + req_type: RequirementType | None = None, + ignore_platform: bool = False, + ) -> BaseProvider: + """Return a resolver provider for resolving package versions. + + Default: :func:`fromager.resolver.default_resolver_provider` + """ + + @_default_hook("fromager.sources", "default_prepare_source") + def prepare_source( + self, + ctx: context.WorkContext, + req: Requirement, + source_filename: pathlib.Path, + version: Version, + ) -> tuple[pathlib.Path, bool]: + """Unpack, modify, and prepare source for building. + + Default: :func:`fromager.sources.default_prepare_source` + """ + + @_default_hook("fromager.packagesettings", "default_update_extra_environ") + def update_extra_environ( + self, + *, + ctx: context.WorkContext, + req: Requirement, + version: Version | None, + sdist_root_dir: pathlib.Path, + extra_environ: dict[str, str], + build_env: build_environment.BuildEnvironment, + ) -> None: + """Update extra_environ dict in-place with additional environment variables. + + Default: :func:`fromager.packagesettings.default_update_extra_environ` + """ diff --git a/tests/test_overrides.py b/tests/test_overrides.py index 7405ff18..019f2e59 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -41,3 +41,52 @@ def default_foo(arg1: typing.Any) -> bool: assert overrides.find_and_invoke( "pkg", "foo", default_foo, arg1="value1", arg2="value2" ) + + +def test_list_hooks() -> None: + hooks = overrides.OverrideHookProtocol.list_hooks() + assert isinstance(hooks, list) + assert len(hooks) == 13 + + +def test_get_default_unknown_hook() -> None: + with pytest.raises(KeyError): + overrides.OverrideHookProtocol.get_default("no_such_hook") + + +def test_check_signature_matching() -> None: + def build_wheel( + ctx: typing.Any, + build_env: typing.Any, + extra_environ: typing.Any, + req: typing.Any, + sdist_root_dir: typing.Any, + version: typing.Any, + build_dir: typing.Any, + ) -> None: + pass + + overrides.OverrideHookProtocol.check_signature(build_wheel) + + +def test_check_signature_unknown_hook() -> None: + def no_such_hook() -> None: + pass + + with pytest.raises(KeyError): + overrides.OverrideHookProtocol.check_signature(no_such_hook) + + +def test_check_signature_args_mismatch() -> None: + def build_wheel(ctx: typing.Any) -> None: + pass + + with pytest.raises(TypeError, match="argument names mismatch"): + overrides.OverrideHookProtocol.check_signature(build_wheel) + + +@pytest.mark.parametrize("hook_name", overrides.OverrideHookProtocol.list_hooks()) +def test_protocol_signature_matches_default(hook_name: str) -> None: + default_fn = overrides.OverrideHookProtocol.get_default(hook_name) + assert callable(default_fn) + overrides.OverrideHookProtocol.check_signature(default_fn, hook_name=hook_name)