Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion modern_di/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import dataclasses
import typing

from modern_di import errors
from modern_di.scope import Scope


@dataclasses.dataclass(frozen=True, slots=True)
class ResolutionStep:
scope: Scope
name: str


class ModernDIError(RuntimeError):
"""Base class for all modern-di errors. Inherits from RuntimeError for backwards compatibility."""

Expand Down Expand Up @@ -51,7 +58,33 @@ def __init__(self, *, provider_scope: Scope) -> None:


class ResolutionError(ModernDIError):
"""Base class for errors raised while resolving a provider."""
"""Base class for errors raised while resolving a provider.

Carries an optional `dependency_path` accumulated as the error propagates up
the resolution chain, so the rendered message shows the full path from the
initially requested type down to the failing dependency.
"""

def __init__(self, message: str) -> None:
self._base_message = message
self.dependency_path: list[ResolutionStep] = []
super().__init__(message)

def prepend_step(self, step: ResolutionStep) -> None:
self.dependency_path.insert(0, step)
self.args = (str(self),)

def __str__(self) -> str:
if not self.dependency_path:
return self._base_message

scope_width = max(len(step.scope.name) for step in self.dependency_path)
lines = ["Cannot resolve dependency chain:"]
for i, step in enumerate(self.dependency_path):
prefix = "" if i == 0 else " " * (i - 1) + "└─> "
lines.append(f" {step.scope.name:<{scope_width}} {prefix}{step.name}")
lines.append(f" caused by: {self._base_message}")
return "\n".join(lines)


class ProviderNotRegisteredError(ResolutionError):
Expand Down
16 changes: 12 additions & 4 deletions modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def __init__( # noqa: PLR0913
def __repr__(self) -> str:
return f"Factory(creator={self._creator!r}, scope={self.scope!r}, cached={self.cache_settings is not None})"

def _resolution_step(self) -> exceptions.ResolutionStep:
name = self.bound_type.__name__ if self.bound_type else getattr(self._creator, "__name__", repr(self._creator))
return exceptions.ResolutionStep(scope=self.scope, name=name)

def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]:
result: dict[str, typing.Any] = {}
for k, v in self._parsed_kwargs.items():
Expand Down Expand Up @@ -140,10 +144,14 @@ def resolve(self, container: "Container") -> types.T_co:
if self.cache_settings and cache_item.cache is not None:
return cache_item.cache

provider_kwargs, static_kwargs = self._ensure_kwargs_cached(container, cache_item)
resolved_kwargs = dict(static_kwargs)
for k, v in provider_kwargs.items():
resolved_kwargs[k] = container.resolve_provider(v)
try:
provider_kwargs, static_kwargs = self._ensure_kwargs_cached(container, cache_item)
resolved_kwargs = dict(static_kwargs)
for k, v in provider_kwargs.items():
resolved_kwargs[k] = container.resolve_provider(v)
except exceptions.ResolutionError as exc:
exc.prepend_step(self._resolution_step())
raise

if not self.cache_settings:
return self._creator(**resolved_kwargs)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_dependency_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import dataclasses

import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import ArgumentResolutionError, ProviderNotRegisteredError, ResolutionStep


@dataclasses.dataclass(kw_only=True, slots=True)
class Database:
pass


@dataclasses.dataclass(kw_only=True, slots=True)
class Repository:
db: Database


@dataclasses.dataclass(kw_only=True, slots=True)
class MyService:
repo: Repository


class IncompleteGroup(Group):
repo = providers.Factory(creator=Repository)
svc = providers.Factory(creator=MyService)


def test_chain_appears_when_arg_unresolvable() -> None:
container = Container(groups=[IncompleteGroup])
with pytest.raises(ArgumentResolutionError) as exc_info:
container.resolve(MyService)

exc = exc_info.value
assert exc.dependency_path == [
ResolutionStep(scope=Scope.APP, name="MyService"),
ResolutionStep(scope=Scope.APP, name="Repository"),
]
assert str(exc) == (
"Cannot resolve dependency chain:\n"
" APP MyService\n"
" APP └─> Repository\n"
" caused by: Argument db of type <class 'tests.test_dependency_path.Database'> "
"cannot be resolved. Trying to build dependency <class 'tests.test_dependency_path.Repository'>."
)


def test_no_chain_when_top_level_provider_missing() -> None:
container = Container()
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(str)
assert exc_info.value.dependency_path == []
assert "Cannot resolve dependency chain" not in str(exc_info.value)


def test_chain_includes_scope_name() -> None:
@dataclasses.dataclass(kw_only=True, slots=True)
class Outer:
inner: Repository

class CrossScope(Group):
repo = providers.Factory(scope=Scope.REQUEST, creator=Repository)
outer = providers.Factory(scope=Scope.REQUEST, creator=Outer)

container = Container(groups=[CrossScope])
request = container.build_child_container(scope=Scope.REQUEST)
with pytest.raises(ArgumentResolutionError) as exc_info:
request.resolve(Outer)

rendered = str(exc_info.value)
assert "REQUEST" in rendered
assert exc_info.value.dependency_path[0].scope == Scope.REQUEST
Loading