From 76b77a8aefeacaa041a64ae1cba8d9176cdd249b Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sat, 2 May 2026 19:03:51 +0300 Subject: [PATCH] Unify validation and remove Container.validate_provider Fold the per-provider validation path into a single Container.validate() that walks the registry via get_dependencies(), detecting cycles, missing kwargs, and missing alias sources in one pass. Drop the parallel provider.validate() machinery whose dict return value was never consumed. Co-Authored-By: Claude Opus 4.7 --- docs/dev/key-concepts.md | 3 +- modern_di/container.py | 3 -- modern_di/providers/abstract.py | 3 -- modern_di/providers/alias.py | 9 ----- modern_di/providers/container_provider.py | 3 -- modern_di/providers/context_provider.py | 3 -- modern_di/providers/factory.py | 14 -------- skills/modern-di/testing.md | 4 +-- tests/providers/test_alias.py | 3 +- tests/providers/test_container_provider.py | 1 - tests/providers/test_context_provider.py | 3 +- tests/providers/test_factory.py | 9 +++-- tests/test_container.py | 42 ++++++++++++++++++++++ 13 files changed, 51 insertions(+), 49 deletions(-) diff --git a/docs/dev/key-concepts.md b/docs/dev/key-concepts.md index e1f75f3..68c9bf1 100644 --- a/docs/dev/key-concepts.md +++ b/docs/dev/key-concepts.md @@ -65,8 +65,7 @@ Container provides methods for resolving dependencies: 1. `resolve_provider(provider)` - Resolve a specific provider instance 2. `resolve(SomeType)` - Resolve by type -3. `validate_provider(provider)` - Validate that the provider's dependency graph is wired correctly without creating real instances (useful at startup) -4. `validate()` - Walk the entire provider graph and detect circular dependencies (raises `RuntimeError` with the cycle path if found) +3. `validate()` - Walk the entire provider graph and detect missing dependencies, missing alias sources, and circular dependencies — without creating real instances (useful at startup) You can also pass `validate=True` to `Container(...)` to run validation automatically at creation time: diff --git a/modern_di/container.py b/modern_di/container.py index 50512bf..010aae3 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -102,9 +102,6 @@ def resolve_provider(self, provider: "AbstractProvider[types.T]") -> types.T: return provider.resolve(self) - def validate_provider(self, provider: "AbstractProvider[types.T]") -> types.T: - return typing.cast(types.T, provider.validate(self)) - def validate(self) -> None: visiting: set[int] = set() visited: set[int] = set() diff --git a/modern_di/providers/abstract.py b/modern_di/providers/abstract.py index aee37bb..2e7265b 100644 --- a/modern_di/providers/abstract.py +++ b/modern_di/providers/abstract.py @@ -28,8 +28,5 @@ def __init__( @abc.abstractmethod def resolve(self, container: "Container") -> typing.Any: ... # noqa: ANN401 - @abc.abstractmethod - def validate(self, container: "Container") -> dict[str, typing.Any]: ... - def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvider[typing.Any]"]: # noqa: ARG002 return {} diff --git a/modern_di/providers/alias.py b/modern_di/providers/alias.py index ed4b61c..d53d23a 100644 --- a/modern_di/providers/alias.py +++ b/modern_di/providers/alias.py @@ -34,14 +34,5 @@ def _find_source(self, container: "Container") -> "AbstractProvider[types.T_co]" def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvider[typing.Any]"]: return {"source": self._find_source(container)} - def validate(self, container: "Container") -> dict[str, typing.Any]: - source = self._find_source(container) - return { - "bound_type": self.bound_type, - "source_type": self._source_type, - "source": source.validate(container), - "self": self, - } - def resolve(self, container: "Container") -> types.T_co: return container.resolve_provider(self._find_source(container)) diff --git a/modern_di/providers/container_provider.py b/modern_di/providers/container_provider.py index a500441..3abe243 100644 --- a/modern_di/providers/container_provider.py +++ b/modern_di/providers/container_provider.py @@ -17,8 +17,5 @@ def __init__(self) -> None: def resolve(self, container: "Container") -> "Container": return container - def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002 - return {"self": self} - container_provider = _ContainerProvider() diff --git a/modern_di/providers/context_provider.py b/modern_di/providers/context_provider.py index 0b1ed47..4449a8a 100644 --- a/modern_di/providers/context_provider.py +++ b/modern_di/providers/context_provider.py @@ -25,9 +25,6 @@ def __init__( def __repr__(self) -> str: return f"ContextProvider(context_type={self._context_type!r}, scope={self.scope!r})" - def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002 - return {"bound_type": self.bound_type, "self": self} - def resolve(self, container: "Container") -> types.T_co | None: container = container.find_container(self.scope) return container.context_registry.find_context(self._context_type) diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index 0a1e3ac..419fa94 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -123,20 +123,6 @@ def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvide provider_kwargs, _ = self._ensure_kwargs_cached(scoped_container, cache_item) return provider_kwargs - def validate(self, container: "Container") -> dict[str, typing.Any]: - container = container.find_container(self.scope) - cache_item = container.cache_registry.fetch_cache_item(self) - provider_kwargs, static_kwargs = self._ensure_kwargs_cached(container, cache_item) - validated_kwargs: dict[str, typing.Any] = {k: v.validate(container) for k, v in provider_kwargs.items()} - validated_kwargs.update(static_kwargs) - return { - "bound_type": self.bound_type, - "creator": self._creator, - "self": self, - "kwargs": validated_kwargs, - "cache_settings": self.cache_settings, - } - def resolve(self, container: "Container") -> types.T_co: container = container.find_container(self.scope) cache_item = container.cache_registry.fetch_cache_item(self) diff --git a/skills/modern-di/testing.md b/skills/modern-di/testing.md index 280d383..e59e8e4 100644 --- a/skills/modern-di/testing.md +++ b/skills/modern-di/testing.md @@ -116,6 +116,6 @@ Check that all dependencies can be resolved without creating real objects: ```python def test_wiring(): container = Container(scope=Scope.APP, groups=[Dependencies]) - container.validate_provider(Dependencies.users_repository) - # Raises RuntimeError if any dependency is missing + container.validate() + # Raises if any provider has missing dependencies, missing alias sources, or circular deps ``` diff --git a/tests/providers/test_alias.py b/tests/providers/test_alias.py index 58d46c5..3da6644 100644 --- a/tests/providers/test_alias.py +++ b/tests/providers/test_alias.py @@ -33,7 +33,6 @@ def test_alias_delegates_to_source() -> None: def test_alias_resolve_provider() -> None: container = Container(groups=[MyGroup]) - container.validate_provider(MyGroup.abstract_repo) instance = container.resolve_provider(MyGroup.abstract_repo) assert isinstance(instance, PostgresRepository) @@ -99,7 +98,7 @@ class G(Group): container = Container(groups=[G]) with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository"): - container.validate_provider(G.abstract) + container.resolve_provider(G.abstract) def test_alias_missing_source_raises_on_container_validate() -> None: diff --git a/tests/providers/test_container_provider.py b/tests/providers/test_container_provider.py index 57521ea..27169d7 100644 --- a/tests/providers/test_container_provider.py +++ b/tests/providers/test_container_provider.py @@ -7,7 +7,6 @@ def test_container_provider_direct_resolving() -> None: request_container = app_container.build_child_container(scope=Scope.REQUEST) assert request_container.resolve_provider(providers.container_provider) is request_container - request_container.validate_provider(providers.container_provider) def test_container_provider_sub_dependency() -> None: diff --git a/tests/providers/test_context_provider.py b/tests/providers/test_context_provider.py index bc665c4..4f652a2 100644 --- a/tests/providers/test_context_provider.py +++ b/tests/providers/test_context_provider.py @@ -22,8 +22,7 @@ class MyGroup(Group): def test_context_provider() -> None: now = datetime.datetime.now(tz=datetime.timezone.utc) - app_container = Container(context={datetime.datetime: now}) - app_container.validate_provider(MyGroup.context_provider) + app_container = Container(groups=[MyGroup], context={datetime.datetime: now}) instance1 = app_container.resolve_provider(MyGroup.context_provider) instance2 = app_container.resolve_provider(MyGroup.context_provider) assert instance1 is instance2 is now diff --git a/tests/providers/test_factory.py b/tests/providers/test_factory.py index 0e63602..44e1e31 100644 --- a/tests/providers/test_factory.py +++ b/tests/providers/test_factory.py @@ -46,7 +46,6 @@ class MyGroup(Group): def test_app_factory() -> None: app_container = Container(groups=[MyGroup]) instance1 = app_container.resolve_provider(MyGroup.app_factory) - app_container.validate_provider(MyGroup.app_factory) instance2 = app_container.resolve(dependency_type=SimpleCreator) assert isinstance(instance1, SimpleCreator) assert isinstance(instance2, SimpleCreator) @@ -64,7 +63,7 @@ def test_app_factory_skip_creator_parsing() -> None: def test_app_factory_unresolvable() -> None: app_container = Container(groups=[MyGroup]) with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type cannot be resolved") as exc: - app_container.validate_provider(MyGroup.app_factory_unresolvable) + app_container.resolve_provider(MyGroup.app_factory_unresolvable) assert exc.value.arg_name == "dep1" assert exc.value.arg_type is str @@ -78,16 +77,16 @@ def test_func_with_union_factory() -> None: def test_func_with_broken_annotation() -> None: app_container = Container(groups=[MyGroup]) with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type None cannot be resolved"): - app_container.validate_provider(MyGroup.func_with_broken_annotation) + app_container.resolve_provider(MyGroup.func_with_broken_annotation) def test_request_factory() -> None: app_container = Container(groups=[MyGroup]) request_container = app_container.build_child_container(scope=Scope.REQUEST) - request_container.validate_provider(MyGroup.request_factory) + request_container.resolve_provider(MyGroup.request_factory) instance1 = request_container.resolve_provider(MyGroup.request_factory) instance2 = request_container.resolve_provider(MyGroup.request_factory) - request_container.validate_provider(MyGroup.request_factory) + request_container.resolve_provider(MyGroup.request_factory) assert instance1 is not instance2 request_container = app_container.build_child_container(scope=Scope.REQUEST) diff --git a/tests/test_container.py b/tests/test_container.py index 5b56fa7..67b4713 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -1,5 +1,6 @@ import copy import dataclasses +import typing import pytest @@ -11,6 +12,7 @@ ProviderNotRegisteredError, ScopeSkippedError, ) +from modern_di.providers.abstract import AbstractProvider def test_container_prevent_copy() -> None: @@ -125,3 +127,43 @@ class ValidGroup(Group): container = Container(groups=[ValidGroup]) container.validate() # should not raise + + +def test_validate_memoizes_diamond() -> None: + @dataclasses.dataclass(kw_only=True, slots=True) + class Bottom: + pass + + @dataclasses.dataclass(kw_only=True, slots=True) + class Left: + bottom: Bottom + + @dataclasses.dataclass(kw_only=True, slots=True) + class Right: + bottom: Bottom + + @dataclasses.dataclass(kw_only=True, slots=True) + class Top: + left: Left + right: Right + + bottom_provider = providers.Factory(creator=Bottom) + call_count = 0 + original_get_dependencies = bottom_provider.get_dependencies + + def counting_get_dependencies(container: Container) -> dict[str, AbstractProvider[typing.Any]]: + nonlocal call_count + call_count += 1 + return original_get_dependencies(container) + + bottom_provider.get_dependencies = counting_get_dependencies # ty: ignore[invalid-assignment] + + class DiamondGroup(Group): + bottom = bottom_provider + left = providers.Factory(creator=Left) + right = providers.Factory(creator=Right) + top = providers.Factory(creator=Top) + + container = Container(groups=[DiamondGroup]) + container.validate() + assert call_count == 1