-
Notifications
You must be signed in to change notification settings - Fork 50
feat: add protocol for override hooks #1126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,6 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import importlib | ||
| import inspect | ||
| import logging | ||
| import pathlib | ||
|
|
@@ -6,8 +9,14 @@ | |
|
|
||
| from packaging.requirements import Requirement | ||
| from packaging.utils import canonicalize_name | ||
| from packaging.version import Version | ||
| from stevedore import extension | ||
|
|
||
| if typing.TYPE_CHECKING: | ||
| from . import build_environment, context | ||
| from .requirements_file import RequirementType | ||
| from .resolver import BaseProvider | ||
|
|
||
| # An interface for reretrieving per-package information which influences | ||
| # the build process for a particular package - i.e. for a given package | ||
| # and build target, what patches should we apply, what environment variables | ||
|
|
@@ -134,3 +143,266 @@ def find_override_method(distname: str, method: str) -> typing.Callable | None: | |
| return None | ||
| logger.info("%s: found %s override", distname, method) | ||
| return typing.cast(typing.Callable, getattr(mod, method)) | ||
|
|
||
|
|
||
| _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) | ||
|
|
||
|
|
||
| def _default_hook(module: str, func: str) -> typing.Callable[[_F], _F]: | ||
| """Decorator that annotates a Protocol method with its default implementation. | ||
|
|
||
| Stores a ``fromager_default`` attribute as a ``(module, func)`` tuple | ||
| on the decorated function so the mapping from hook name to default can | ||
| be discovered at runtime. | ||
| """ | ||
|
|
||
| def decorator(fn: _F) -> _F: | ||
| fn.fromager_default = (module, func) # type: ignore[attr-defined] | ||
| return fn | ||
|
|
||
| return decorator | ||
|
|
||
|
|
||
| class OverrideHookProtocol(typing.Protocol): | ||
| """Protocol defining the interface for per-package override hooks. | ||
|
|
||
| Override modules may implement any subset of these methods to customize | ||
| the build process for a specific package. See the default implementations | ||
| for each hook's behavior when no override is provided. | ||
| """ | ||
|
|
||
| @classmethod | ||
| def list_hooks(cls) -> list[str]: | ||
| """Return a list of hook names defined on this Protocol.""" | ||
| return [ | ||
| name for name, obj in vars(cls).items() if hasattr(obj, "fromager_default") | ||
| ] | ||
|
|
||
| @classmethod | ||
| def get_default(cls, hook_name: str) -> typing.Callable[..., typing.Any]: | ||
| """Return the default function object for a hook name.""" | ||
| obj = getattr(cls, hook_name, None) | ||
| if obj is None or not hasattr(obj, "fromager_default"): | ||
| raise KeyError(hook_name) | ||
| module_name, func_name = obj.fromager_default | ||
| module = importlib.import_module(module_name) | ||
| return typing.cast(typing.Callable[..., typing.Any], getattr(module, func_name)) | ||
|
Comment on lines
+181
to
+189
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If 🛡️ Proposed fix module = importlib.import_module(module_name)
- return typing.cast(typing.Callable[..., typing.Any], getattr(module, func_name))
+ func = getattr(module, func_name, None)
+ if func is None:
+ raise AttributeError(
+ f"Hook {hook_name!r}: no attribute {func_name!r} in {module_name!r}"
+ ) from None
+ return typing.cast(typing.Callable[..., typing.Any], func)🤖 Prompt for AI Agents |
||
|
|
||
| @classmethod | ||
| def check_signature( | ||
| cls, | ||
| func: typing.Callable[..., typing.Any], | ||
| *, | ||
| hook_name: str | None = None, | ||
| ) -> None: | ||
| """Check that a function's argument names match the protocol method. | ||
|
|
||
| Only argument names are compared; the check ignores whether arguments | ||
| are positional or keyword-only because all hooks are called with | ||
| keyword arguments. | ||
| """ | ||
| if hook_name is None: | ||
| hook_name = func.__name__ | ||
| proto_method = getattr(cls, hook_name, None) | ||
| if proto_method is None or not hasattr(proto_method, "fromager_default"): | ||
| raise KeyError(hook_name) | ||
| proto_spec = inspect.getfullargspec(proto_method) | ||
| # Skip 'self' (first parameter of a protocol method) | ||
| expected_args = set(proto_spec.args[1:] + proto_spec.kwonlyargs) | ||
| func_spec = inspect.getfullargspec(func) | ||
| func_args = set(func_spec.args + func_spec.kwonlyargs) | ||
| if expected_args != func_args: | ||
| raise TypeError( | ||
| f"{hook_name}: argument names mismatch: " | ||
| f"expected {sorted(expected_args)}, got {sorted(func_args)}" | ||
| ) | ||
|
|
||
| @_default_hook("fromager.wheels", "default_add_extra_metadata_to_wheels") | ||
| def add_extra_metadata_to_wheels( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| version: Version, | ||
| extra_environ: dict[str, str], | ||
| sdist_root_dir: pathlib.Path, | ||
| dist_info_dir: pathlib.Path, | ||
| ) -> dict[str, typing.Any]: | ||
| """Add extra metadata files to built wheels. | ||
|
|
||
| Default: :func:`fromager.wheels.default_add_extra_metadata_to_wheels` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.sources", "default_build_sdist") | ||
| def build_sdist( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| extra_environ: dict, | ||
| req: Requirement, | ||
| version: Version, | ||
| sdist_root_dir: pathlib.Path, | ||
| build_env: build_environment.BuildEnvironment, | ||
| build_dir: pathlib.Path, | ||
| ) -> pathlib.Path: | ||
| """Build an sdist from the prepared source tree. | ||
|
|
||
| Default: :func:`fromager.sources.default_build_sdist` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.wheels", "default_build_wheel") | ||
| def build_wheel( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| build_env: build_environment.BuildEnvironment, | ||
| extra_environ: dict[str, str], | ||
| req: Requirement, | ||
| sdist_root_dir: pathlib.Path, | ||
| version: Version, | ||
| build_dir: pathlib.Path, | ||
| ) -> pathlib.Path: | ||
| """Build a wheel from the prepared source tree. | ||
|
|
||
| Default: :func:`fromager.wheels.default_build_wheel` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.sources", "default_download_source") | ||
| def download_source( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| version: Version, | ||
| download_url: str, | ||
| sdists_downloads_dir: pathlib.Path, | ||
| ) -> pathlib.Path: | ||
| """Download the source archive for a requirement. | ||
|
|
||
| Default: :func:`fromager.sources.default_download_source` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.finders", "default_expected_source_archive_name") | ||
| def expected_source_archive_name( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| dist_version: str, | ||
| ) -> str | None: | ||
| """Return the expected filename for a source archive. | ||
|
|
||
| Default: :func:`fromager.finders.default_expected_source_archive_name` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.finders", "default_expected_source_directory_name") | ||
| def expected_source_directory_name( | ||
| self, | ||
| req: Requirement, | ||
| dist_version: str, | ||
| ) -> str: | ||
| """Return the expected directory name after unpacking a source archive. | ||
|
|
||
| Default: :func:`fromager.finders.default_expected_source_directory_name` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.dependencies", "default_get_build_backend_dependencies") | ||
| def get_build_backend_dependencies( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| sdist_root_dir: pathlib.Path, | ||
| build_dir: pathlib.Path, | ||
| extra_environ: dict[str, str], | ||
| build_env: build_environment.BuildEnvironment, | ||
| ) -> typing.Iterable[str]: | ||
| """Get build backend dependencies (PEP 517 get_requires_for_build_wheel). | ||
|
|
||
| Default: :func:`fromager.dependencies.default_get_build_backend_dependencies` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.dependencies", "default_get_build_sdist_dependencies") | ||
| def get_build_sdist_dependencies( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| sdist_root_dir: pathlib.Path, | ||
| build_dir: pathlib.Path, | ||
| extra_environ: dict[str, str], | ||
| build_env: build_environment.BuildEnvironment, | ||
| ) -> typing.Iterable[str]: | ||
| """Get build sdist dependencies. | ||
|
|
||
| Default: :func:`fromager.dependencies.default_get_build_sdist_dependencies` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.dependencies", "default_get_build_system_dependencies") | ||
| def get_build_system_dependencies( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| sdist_root_dir: pathlib.Path, | ||
| build_dir: pathlib.Path, | ||
| ) -> typing.Iterable[str]: | ||
| """Get build system dependencies from pyproject.toml [build-system] requires. | ||
|
|
||
| Default: :func:`fromager.dependencies.default_get_build_system_dependencies` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.dependencies", "default_get_install_dependencies_of_sdist") | ||
| def get_install_dependencies_of_sdist( | ||
| self, | ||
| *, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| version: Version, | ||
| sdist_root_dir: pathlib.Path, | ||
| build_env: build_environment.BuildEnvironment, | ||
| extra_environ: dict[str, str], | ||
| build_dir: pathlib.Path, | ||
| config_settings: dict[str, str], | ||
| ) -> set[Requirement]: | ||
| """Get install dependencies (Requires-Dist) from source distribution. | ||
|
|
||
| Default: :func:`fromager.dependencies.default_get_install_dependencies_of_sdist` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.resolver", "default_resolver_provider") | ||
| def get_resolver_provider( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| sdist_server_url: str, | ||
| include_sdists: bool, | ||
| include_wheels: bool, | ||
| req_type: RequirementType | None = None, | ||
| ignore_platform: bool = False, | ||
| ) -> BaseProvider: | ||
| """Return a resolver provider for resolving package versions. | ||
|
|
||
| Default: :func:`fromager.resolver.default_resolver_provider` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.sources", "default_prepare_source") | ||
| def prepare_source( | ||
| self, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| source_filename: pathlib.Path, | ||
| version: Version, | ||
| ) -> tuple[pathlib.Path, bool]: | ||
| """Unpack, modify, and prepare source for building. | ||
|
|
||
| Default: :func:`fromager.sources.default_prepare_source` | ||
| """ | ||
|
|
||
| @_default_hook("fromager.packagesettings", "default_update_extra_environ") | ||
| def update_extra_environ( | ||
| self, | ||
| *, | ||
| ctx: context.WorkContext, | ||
| req: Requirement, | ||
| version: Version | None, | ||
| sdist_root_dir: pathlib.Path, | ||
| extra_environ: dict[str, str], | ||
| build_env: build_environment.BuildEnvironment, | ||
| ) -> None: | ||
| """Update extra_environ dict in-place with additional environment variables. | ||
|
|
||
| Default: :func:`fromager.packagesettings.default_update_extra_environ` | ||
| """ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,3 +41,52 @@ def default_foo(arg1: typing.Any) -> bool: | |
| assert overrides.find_and_invoke( | ||
| "pkg", "foo", default_foo, arg1="value1", arg2="value2" | ||
| ) | ||
|
|
||
|
|
||
| def test_list_hooks() -> None: | ||
| hooks = overrides.OverrideHookProtocol.list_hooks() | ||
| assert isinstance(hooks, list) | ||
| assert len(hooks) == 13 | ||
|
Comment on lines
+46
to
+49
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hardcoded hook count breaks silently when hooks are added or removed.
♻️ Proposed fix def test_list_hooks() -> None:
hooks = overrides.OverrideHookProtocol.list_hooks()
assert isinstance(hooks, list)
- assert len(hooks) == 13
+ # Spot-check a stable subset; the parametrized test validates all hooks.
+ assert "build_wheel" in hooks
+ assert "download_source" in hooks
+ assert len(hooks) > 0🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def test_get_default_unknown_hook() -> None: | ||
| with pytest.raises(KeyError): | ||
| overrides.OverrideHookProtocol.get_default("no_such_hook") | ||
|
|
||
|
|
||
| def test_check_signature_matching() -> None: | ||
| def build_wheel( | ||
| ctx: typing.Any, | ||
| build_env: typing.Any, | ||
| extra_environ: typing.Any, | ||
| req: typing.Any, | ||
| sdist_root_dir: typing.Any, | ||
| version: typing.Any, | ||
| build_dir: typing.Any, | ||
| ) -> None: | ||
| pass | ||
|
|
||
| overrides.OverrideHookProtocol.check_signature(build_wheel) | ||
|
|
||
|
|
||
| def test_check_signature_unknown_hook() -> None: | ||
| def no_such_hook() -> None: | ||
| pass | ||
|
|
||
| with pytest.raises(KeyError): | ||
| overrides.OverrideHookProtocol.check_signature(no_such_hook) | ||
|
|
||
|
|
||
| def test_check_signature_args_mismatch() -> None: | ||
| def build_wheel(ctx: typing.Any) -> None: | ||
| pass | ||
|
|
||
| with pytest.raises(TypeError, match="argument names mismatch"): | ||
| overrides.OverrideHookProtocol.check_signature(build_wheel) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("hook_name", overrides.OverrideHookProtocol.list_hooks()) | ||
| def test_protocol_signature_matches_default(hook_name: str) -> None: | ||
| default_fn = overrides.OverrideHookProtocol.get_default(hook_name) | ||
| assert callable(default_fn) | ||
| overrides.OverrideHookProtocol.check_signature(default_fn, hook_name=hook_name) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another way to do this would be to have fromager define a type so we could pass fewer individual arguments to hooks. For example, the build hooks both take the same arguments. We could define a type and pass 1 value to the hook. Then when we add fields, the hooks don't have to change if they don't need the value.
Maybe the methods that manipulate the source (download, prepare, etc.) could be standardized to take a single type, too?
How many different types would we need?