Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/dev/key-concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ Container provides methods for resolving dependencies:

1. `resolve_provider(provider)` - Resolve a specific provider instance
2. `resolve(SomeType)` - Resolve by type
3. `validate_provider(provider)` - Validate that the provider's dependency graph is wired correctly without creating real instances (useful at startup)
4. `validate()` - Walk the entire provider graph and detect circular dependencies (raises `RuntimeError` with the cycle path if found)
3. `validate()` - Walk the entire provider graph and detect missing dependencies, missing alias sources, and circular dependencies — without creating real instances (useful at startup)

You can also pass `validate=True` to `Container(...)` to run validation automatically at creation time:

Expand Down
3 changes: 0 additions & 3 deletions modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ def resolve_provider(self, provider: "AbstractProvider[types.T]") -> types.T:

return provider.resolve(self)

def validate_provider(self, provider: "AbstractProvider[types.T]") -> types.T:
return typing.cast(types.T, provider.validate(self))

def validate(self) -> None:
visiting: set[int] = set()
visited: set[int] = set()
Expand Down
3 changes: 0 additions & 3 deletions modern_di/providers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,5 @@ def __init__(
@abc.abstractmethod
def resolve(self, container: "Container") -> typing.Any: ... # noqa: ANN401

@abc.abstractmethod
def validate(self, container: "Container") -> dict[str, typing.Any]: ...

def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvider[typing.Any]"]: # noqa: ARG002
return {}
9 changes: 0 additions & 9 deletions modern_di/providers/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,5 @@ def _find_source(self, container: "Container") -> "AbstractProvider[types.T_co]"
def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvider[typing.Any]"]:
return {"source": self._find_source(container)}

def validate(self, container: "Container") -> dict[str, typing.Any]:
source = self._find_source(container)
return {
"bound_type": self.bound_type,
"source_type": self._source_type,
"source": source.validate(container),
"self": self,
}

def resolve(self, container: "Container") -> types.T_co:
return container.resolve_provider(self._find_source(container))
3 changes: 0 additions & 3 deletions modern_di/providers/container_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,5 @@ def __init__(self) -> None:
def resolve(self, container: "Container") -> "Container":
return container

def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002
return {"self": self}


container_provider = _ContainerProvider()
3 changes: 0 additions & 3 deletions modern_di/providers/context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def __init__(
def __repr__(self) -> str:
return f"ContextProvider(context_type={self._context_type!r}, scope={self.scope!r})"

def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002
return {"bound_type": self.bound_type, "self": self}

def resolve(self, container: "Container") -> types.T_co | None:
container = container.find_container(self.scope)
return container.context_registry.find_context(self._context_type)
14 changes: 0 additions & 14 deletions modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,6 @@ def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvide
provider_kwargs, _ = self._ensure_kwargs_cached(scoped_container, cache_item)
return provider_kwargs

def validate(self, container: "Container") -> dict[str, typing.Any]:
container = container.find_container(self.scope)
cache_item = container.cache_registry.fetch_cache_item(self)
provider_kwargs, static_kwargs = self._ensure_kwargs_cached(container, cache_item)
validated_kwargs: dict[str, typing.Any] = {k: v.validate(container) for k, v in provider_kwargs.items()}
validated_kwargs.update(static_kwargs)
return {
"bound_type": self.bound_type,
"creator": self._creator,
"self": self,
"kwargs": validated_kwargs,
"cache_settings": self.cache_settings,
}

def resolve(self, container: "Container") -> types.T_co:
container = container.find_container(self.scope)
cache_item = container.cache_registry.fetch_cache_item(self)
Expand Down
4 changes: 2 additions & 2 deletions skills/modern-di/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,6 @@ Check that all dependencies can be resolved without creating real objects:
```python
def test_wiring():
container = Container(scope=Scope.APP, groups=[Dependencies])
container.validate_provider(Dependencies.users_repository)
# Raises RuntimeError if any dependency is missing
container.validate()
# Raises if any provider has missing dependencies, missing alias sources, or circular deps
```
3 changes: 1 addition & 2 deletions tests/providers/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def test_alias_delegates_to_source() -> None:

def test_alias_resolve_provider() -> None:
container = Container(groups=[MyGroup])
container.validate_provider(MyGroup.abstract_repo)
instance = container.resolve_provider(MyGroup.abstract_repo)
assert isinstance(instance, PostgresRepository)

Expand Down Expand Up @@ -99,7 +98,7 @@ class G(Group):

container = Container(groups=[G])
with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository"):
container.validate_provider(G.abstract)
container.resolve_provider(G.abstract)


def test_alias_missing_source_raises_on_container_validate() -> None:
Expand Down
1 change: 0 additions & 1 deletion tests/providers/test_container_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ def test_container_provider_direct_resolving() -> None:

request_container = app_container.build_child_container(scope=Scope.REQUEST)
assert request_container.resolve_provider(providers.container_provider) is request_container
request_container.validate_provider(providers.container_provider)


def test_container_provider_sub_dependency() -> None:
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/test_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class MyGroup(Group):

def test_context_provider() -> None:
now = datetime.datetime.now(tz=datetime.timezone.utc)
app_container = Container(context={datetime.datetime: now})
app_container.validate_provider(MyGroup.context_provider)
app_container = Container(groups=[MyGroup], context={datetime.datetime: now})
instance1 = app_container.resolve_provider(MyGroup.context_provider)
instance2 = app_container.resolve_provider(MyGroup.context_provider)
assert instance1 is instance2 is now
Expand Down
9 changes: 4 additions & 5 deletions tests/providers/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class MyGroup(Group):
def test_app_factory() -> None:
app_container = Container(groups=[MyGroup])
instance1 = app_container.resolve_provider(MyGroup.app_factory)
app_container.validate_provider(MyGroup.app_factory)
instance2 = app_container.resolve(dependency_type=SimpleCreator)
assert isinstance(instance1, SimpleCreator)
assert isinstance(instance2, SimpleCreator)
Expand All @@ -64,7 +63,7 @@ def test_app_factory_skip_creator_parsing() -> None:
def test_app_factory_unresolvable() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type <class 'str'> cannot be resolved") as exc:
app_container.validate_provider(MyGroup.app_factory_unresolvable)
app_container.resolve_provider(MyGroup.app_factory_unresolvable)
assert exc.value.arg_name == "dep1"
assert exc.value.arg_type is str

Expand All @@ -78,16 +77,16 @@ def test_func_with_union_factory() -> None:
def test_func_with_broken_annotation() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type None cannot be resolved"):
app_container.validate_provider(MyGroup.func_with_broken_annotation)
app_container.resolve_provider(MyGroup.func_with_broken_annotation)


def test_request_factory() -> None:
app_container = Container(groups=[MyGroup])
request_container = app_container.build_child_container(scope=Scope.REQUEST)
request_container.validate_provider(MyGroup.request_factory)
request_container.resolve_provider(MyGroup.request_factory)
instance1 = request_container.resolve_provider(MyGroup.request_factory)
instance2 = request_container.resolve_provider(MyGroup.request_factory)
request_container.validate_provider(MyGroup.request_factory)
request_container.resolve_provider(MyGroup.request_factory)
assert instance1 is not instance2

request_container = app_container.build_child_container(scope=Scope.REQUEST)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import dataclasses
import typing

import pytest

Expand All @@ -11,6 +12,7 @@
ProviderNotRegisteredError,
ScopeSkippedError,
)
from modern_di.providers.abstract import AbstractProvider


def test_container_prevent_copy() -> None:
Expand Down Expand Up @@ -125,3 +127,43 @@ class ValidGroup(Group):

container = Container(groups=[ValidGroup])
container.validate() # should not raise


def test_validate_memoizes_diamond() -> None:
@dataclasses.dataclass(kw_only=True, slots=True)
class Bottom:
pass

@dataclasses.dataclass(kw_only=True, slots=True)
class Left:
bottom: Bottom

@dataclasses.dataclass(kw_only=True, slots=True)
class Right:
bottom: Bottom

@dataclasses.dataclass(kw_only=True, slots=True)
class Top:
left: Left
right: Right

bottom_provider = providers.Factory(creator=Bottom)
call_count = 0
original_get_dependencies = bottom_provider.get_dependencies

def counting_get_dependencies(container: Container) -> dict[str, AbstractProvider[typing.Any]]:
nonlocal call_count
call_count += 1
return original_get_dependencies(container)

bottom_provider.get_dependencies = counting_get_dependencies # ty: ignore[invalid-assignment]

class DiamondGroup(Group):
bottom = bottom_provider
left = providers.Factory(creator=Left)
right = providers.Factory(creator=Right)
top = providers.Factory(creator=Top)

container = Container(groups=[DiamondGroup])
container.validate()
assert call_count == 1
Loading