Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/fromager/commands/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,21 @@ def lint(
errors += 1
logger.error(f"ERROR: plugin name {name} should be {expected_name}")

logger.info("Checking override hook signatures...")
hook_names = overrides.OverrideHookProtocol.list_hooks()
for ext in exts:
mod = ext.plugin
for hook_name in hook_names:
func = getattr(mod, hook_name, None)
if func is None:
continue
try:
overrides.OverrideHookProtocol.check_signature(
func, hook_name=hook_name
)
except TypeError as e:
errors += 1
logger.error(f"ERROR: override {ext.name}.{hook_name}: {e}")

if errors:
raise SystemExit(f"Found {errors} errors")
272 changes: 272 additions & 0 deletions src/fromager/overrides.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another way to do this would be to have fromager define a type so we could pass fewer individual arguments to hooks. For example, the build hooks both take the same arguments. We could define a type and pass 1 value to the hook. Then when we add fields, the hooks don't have to change if they don't need the value.

Maybe the methods that manipulate the source (download, prepare, etc.) could be standardized to take a single type, too?

How many different types would we need?

Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import importlib
import inspect
import logging
import pathlib
Expand All @@ -6,8 +9,14 @@

from packaging.requirements import Requirement
from packaging.utils import canonicalize_name
from packaging.version import Version
from stevedore import extension

if typing.TYPE_CHECKING:
from . import build_environment, context
from .requirements_file import RequirementType
from .resolver import BaseProvider

# An interface for reretrieving per-package information which influences
# the build process for a particular package - i.e. for a given package
# and build target, what patches should we apply, what environment variables
Expand Down Expand Up @@ -134,3 +143,266 @@ def find_override_method(distname: str, method: str) -> typing.Callable | None:
return None
logger.info("%s: found %s override", distname, method)
return typing.cast(typing.Callable, getattr(mod, method))


_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any])


def _default_hook(module: str, func: str) -> typing.Callable[[_F], _F]:
"""Decorator that annotates a Protocol method with its default implementation.

Stores a ``fromager_default`` attribute as a ``(module, func)`` tuple
on the decorated function so the mapping from hook name to default can
be discovered at runtime.
"""

def decorator(fn: _F) -> _F:
fn.fromager_default = (module, func) # type: ignore[attr-defined]
return fn

return decorator


class OverrideHookProtocol(typing.Protocol):
"""Protocol defining the interface for per-package override hooks.

Override modules may implement any subset of these methods to customize
the build process for a specific package. See the default implementations
for each hook's behavior when no override is provided.
"""

@classmethod
def list_hooks(cls) -> list[str]:
"""Return a list of hook names defined on this Protocol."""
return [
name for name, obj in vars(cls).items() if hasattr(obj, "fromager_default")
]

@classmethod
def get_default(cls, hook_name: str) -> typing.Callable[..., typing.Any]:
"""Return the default function object for a hook name."""
obj = getattr(cls, hook_name, None)
if obj is None or not hasattr(obj, "fromager_default"):
raise KeyError(hook_name)
module_name, func_name = obj.fromager_default
module = importlib.import_module(module_name)
return typing.cast(typing.Callable[..., typing.Any], getattr(module, func_name))
Comment on lines +181 to +189
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

get_default silently propagates AttributeError with no context.

If func_name doesn't exist in the imported module (e.g., typo in _default_hook args), getattr(module, func_name) raises a bare AttributeError with no mention of which hook is broken.

🛡️ Proposed fix
     module = importlib.import_module(module_name)
-    return typing.cast(typing.Callable[..., typing.Any], getattr(module, func_name))
+    func = getattr(module, func_name, None)
+    if func is None:
+        raise AttributeError(
+            f"Hook {hook_name!r}: no attribute {func_name!r} in {module_name!r}"
+        ) from None
+    return typing.cast(typing.Callable[..., typing.Any], func)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/fromager/overrides.py` around lines 181 - 189, get_default currently lets
a raw AttributeError bubble up if the default function named by
obj.fromager_default is missing; update get_default (the classmethod named
get_default and the use of obj.fromager_default -> module_name, func_name) to
catch AttributeError around getattr(module, func_name) and raise a KeyError
referencing the hook_name and the missing func_name/module_name (include the
original exception as the __cause__ with "from") so callers get clear context
about which hook/default is broken.


@classmethod
def check_signature(
cls,
func: typing.Callable[..., typing.Any],
*,
hook_name: str | None = None,
) -> None:
"""Check that a function's argument names match the protocol method.

Only argument names are compared; the check ignores whether arguments
are positional or keyword-only because all hooks are called with
keyword arguments.
"""
if hook_name is None:
hook_name = func.__name__
proto_method = getattr(cls, hook_name, None)
if proto_method is None or not hasattr(proto_method, "fromager_default"):
raise KeyError(hook_name)
proto_spec = inspect.getfullargspec(proto_method)
# Skip 'self' (first parameter of a protocol method)
expected_args = set(proto_spec.args[1:] + proto_spec.kwonlyargs)
func_spec = inspect.getfullargspec(func)
func_args = set(func_spec.args + func_spec.kwonlyargs)
if expected_args != func_args:
raise TypeError(
f"{hook_name}: argument names mismatch: "
f"expected {sorted(expected_args)}, got {sorted(func_args)}"
)

@_default_hook("fromager.wheels", "default_add_extra_metadata_to_wheels")
def add_extra_metadata_to_wheels(
self,
ctx: context.WorkContext,
req: Requirement,
version: Version,
extra_environ: dict[str, str],
sdist_root_dir: pathlib.Path,
dist_info_dir: pathlib.Path,
) -> dict[str, typing.Any]:
"""Add extra metadata files to built wheels.

Default: :func:`fromager.wheels.default_add_extra_metadata_to_wheels`
"""

@_default_hook("fromager.sources", "default_build_sdist")
def build_sdist(
self,
ctx: context.WorkContext,
extra_environ: dict,
req: Requirement,
version: Version,
sdist_root_dir: pathlib.Path,
build_env: build_environment.BuildEnvironment,
build_dir: pathlib.Path,
) -> pathlib.Path:
"""Build an sdist from the prepared source tree.

Default: :func:`fromager.sources.default_build_sdist`
"""

@_default_hook("fromager.wheels", "default_build_wheel")
def build_wheel(
self,
ctx: context.WorkContext,
build_env: build_environment.BuildEnvironment,
extra_environ: dict[str, str],
req: Requirement,
sdist_root_dir: pathlib.Path,
version: Version,
build_dir: pathlib.Path,
) -> pathlib.Path:
"""Build a wheel from the prepared source tree.

Default: :func:`fromager.wheels.default_build_wheel`
"""

@_default_hook("fromager.sources", "default_download_source")
def download_source(
self,
ctx: context.WorkContext,
req: Requirement,
version: Version,
download_url: str,
sdists_downloads_dir: pathlib.Path,
) -> pathlib.Path:
"""Download the source archive for a requirement.

Default: :func:`fromager.sources.default_download_source`
"""

@_default_hook("fromager.finders", "default_expected_source_archive_name")
def expected_source_archive_name(
self,
ctx: context.WorkContext,
req: Requirement,
dist_version: str,
) -> str | None:
"""Return the expected filename for a source archive.

Default: :func:`fromager.finders.default_expected_source_archive_name`
"""

@_default_hook("fromager.finders", "default_expected_source_directory_name")
def expected_source_directory_name(
self,
req: Requirement,
dist_version: str,
) -> str:
"""Return the expected directory name after unpacking a source archive.

Default: :func:`fromager.finders.default_expected_source_directory_name`
"""

@_default_hook("fromager.dependencies", "default_get_build_backend_dependencies")
def get_build_backend_dependencies(
self,
ctx: context.WorkContext,
req: Requirement,
sdist_root_dir: pathlib.Path,
build_dir: pathlib.Path,
extra_environ: dict[str, str],
build_env: build_environment.BuildEnvironment,
) -> typing.Iterable[str]:
"""Get build backend dependencies (PEP 517 get_requires_for_build_wheel).

Default: :func:`fromager.dependencies.default_get_build_backend_dependencies`
"""

@_default_hook("fromager.dependencies", "default_get_build_sdist_dependencies")
def get_build_sdist_dependencies(
self,
ctx: context.WorkContext,
req: Requirement,
sdist_root_dir: pathlib.Path,
build_dir: pathlib.Path,
extra_environ: dict[str, str],
build_env: build_environment.BuildEnvironment,
) -> typing.Iterable[str]:
"""Get build sdist dependencies.

Default: :func:`fromager.dependencies.default_get_build_sdist_dependencies`
"""

@_default_hook("fromager.dependencies", "default_get_build_system_dependencies")
def get_build_system_dependencies(
self,
ctx: context.WorkContext,
req: Requirement,
sdist_root_dir: pathlib.Path,
build_dir: pathlib.Path,
) -> typing.Iterable[str]:
"""Get build system dependencies from pyproject.toml [build-system] requires.

Default: :func:`fromager.dependencies.default_get_build_system_dependencies`
"""

@_default_hook("fromager.dependencies", "default_get_install_dependencies_of_sdist")
def get_install_dependencies_of_sdist(
self,
*,
ctx: context.WorkContext,
req: Requirement,
version: Version,
sdist_root_dir: pathlib.Path,
build_env: build_environment.BuildEnvironment,
extra_environ: dict[str, str],
build_dir: pathlib.Path,
config_settings: dict[str, str],
) -> set[Requirement]:
"""Get install dependencies (Requires-Dist) from source distribution.

Default: :func:`fromager.dependencies.default_get_install_dependencies_of_sdist`
"""

@_default_hook("fromager.resolver", "default_resolver_provider")
def get_resolver_provider(
self,
ctx: context.WorkContext,
req: Requirement,
sdist_server_url: str,
include_sdists: bool,
include_wheels: bool,
req_type: RequirementType | None = None,
ignore_platform: bool = False,
) -> BaseProvider:
"""Return a resolver provider for resolving package versions.

Default: :func:`fromager.resolver.default_resolver_provider`
"""

@_default_hook("fromager.sources", "default_prepare_source")
def prepare_source(
self,
ctx: context.WorkContext,
req: Requirement,
source_filename: pathlib.Path,
version: Version,
) -> tuple[pathlib.Path, bool]:
"""Unpack, modify, and prepare source for building.

Default: :func:`fromager.sources.default_prepare_source`
"""

@_default_hook("fromager.packagesettings", "default_update_extra_environ")
def update_extra_environ(
self,
*,
ctx: context.WorkContext,
req: Requirement,
version: Version | None,
sdist_root_dir: pathlib.Path,
extra_environ: dict[str, str],
build_env: build_environment.BuildEnvironment,
) -> None:
"""Update extra_environ dict in-place with additional environment variables.

Default: :func:`fromager.packagesettings.default_update_extra_environ`
"""
49 changes: 49 additions & 0 deletions tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,52 @@ def default_foo(arg1: typing.Any) -> bool:
assert overrides.find_and_invoke(
"pkg", "foo", default_foo, arg1="value1", arg2="value2"
)


def test_list_hooks() -> None:
hooks = overrides.OverrideHookProtocol.list_hooks()
assert isinstance(hooks, list)
assert len(hooks) == 13
Comment on lines +46 to +49
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Hardcoded hook count breaks silently when hooks are added or removed.

assert len(hooks) == 13 will fail on every hook addition/removal without communicating which hook was added or removed. The parametrized test_protocol_signature_matches_default already exercises every hook end-to-end; this test adds little value in its current form.

♻️ Proposed fix
 def test_list_hooks() -> None:
     hooks = overrides.OverrideHookProtocol.list_hooks()
     assert isinstance(hooks, list)
-    assert len(hooks) == 13
+    # Spot-check a stable subset; the parametrized test validates all hooks.
+    assert "build_wheel" in hooks
+    assert "download_source" in hooks
+    assert len(hooks) > 0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_overrides.py` around lines 46 - 49, Remove the fragile hardcoded
count check in test_list_hooks: in the function test_list_hooks (which calls
overrides.OverrideHookProtocol.list_hooks) delete the assertion assert
len(hooks) == 13 and instead assert that hooks is non-empty (e.g., assert hooks)
and that there are no duplicates (e.g., assert len(set(hooks)) == len(hooks));
keep the isinstance(hooks, list) check and rely on the existing parametrized
test_protocol_signature_matches_default for per-hook behavior coverage.



def test_get_default_unknown_hook() -> None:
with pytest.raises(KeyError):
overrides.OverrideHookProtocol.get_default("no_such_hook")


def test_check_signature_matching() -> None:
def build_wheel(
ctx: typing.Any,
build_env: typing.Any,
extra_environ: typing.Any,
req: typing.Any,
sdist_root_dir: typing.Any,
version: typing.Any,
build_dir: typing.Any,
) -> None:
pass

overrides.OverrideHookProtocol.check_signature(build_wheel)


def test_check_signature_unknown_hook() -> None:
def no_such_hook() -> None:
pass

with pytest.raises(KeyError):
overrides.OverrideHookProtocol.check_signature(no_such_hook)


def test_check_signature_args_mismatch() -> None:
def build_wheel(ctx: typing.Any) -> None:
pass

with pytest.raises(TypeError, match="argument names mismatch"):
overrides.OverrideHookProtocol.check_signature(build_wheel)


@pytest.mark.parametrize("hook_name", overrides.OverrideHookProtocol.list_hooks())
def test_protocol_signature_matches_default(hook_name: str) -> None:
default_fn = overrides.OverrideHookProtocol.get_default(hook_name)
assert callable(default_fn)
overrides.OverrideHookProtocol.check_signature(default_fn, hook_name=hook_name)
Loading