Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def find_container(self, scope: Scope) -> "typing_extensions.Self":
def resolve(self, dependency_type: type[types.T]) -> types.T:
provider = self.providers_registry.find_provider(dependency_type)
if not provider:
raise exceptions.ProviderNotRegisteredError(provider_type=dependency_type)
raise exceptions.ProviderNotRegisteredError(
provider_type=dependency_type,
suggestions=self.providers_registry.build_suggestions(dependency_type),
)

return self.resolve_provider(provider)

Expand Down
4 changes: 4 additions & 0 deletions modern_di/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
)
CONTAINER_SCOPE_IS_SKIPPED_ERROR = "Provider of scope {provider_scope} is skipped in the chain of containers."
CONTAINER_MISSING_PROVIDER_ERROR = "Provider of type {provider_type} is not registered in providers registry."
SUGGESTION_HEADER = "Did you mean:"
SUGGESTION_SUBCLASS = " - {type_name} (registered subclass, scope={scope})"
SUGGESTION_BASECLASS = " - {type_name} (registered base class, scope={scope})"
SUGGESTION_SIMILAR = " - {type_name} (similar name, scope={scope})"
FACTORY_ARGUMENT_RESOLUTION_ERROR = (
"Argument {arg_name} of type {arg_type} cannot be resolved. Trying to build dependency {bound_type}."
)
Expand Down
13 changes: 11 additions & 2 deletions modern_di/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,18 @@ def __str__(self) -> str:


class ProviderNotRegisteredError(ResolutionError):
def __init__(self, *, provider_type: type) -> None:
def __init__(
self,
*,
provider_type: type,
suggestions: list[str] | None = None,
) -> None:
self.provider_type = provider_type
super().__init__(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type))
self.suggestions = suggestions or []
message = errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type)
if self.suggestions:
message += "\n" + errors.SUGGESTION_HEADER + "\n" + "\n".join(self.suggestions)
super().__init__(message)


class ArgumentResolutionError(ResolutionError):
Expand Down
54 changes: 53 additions & 1 deletion modern_di/registries/providers_registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
import difflib
import inspect
import typing

from modern_di import exceptions, types
from modern_di import errors, exceptions, types
from modern_di.providers.abstract import AbstractProvider


_MAX_SUGGESTIONS = 3
_SIMILARITY_CUTOFF = 0.6


def _hierarchy_hint(requested_type: type, provider: AbstractProvider[typing.Any]) -> str | None:
registered = provider.bound_type
if registered is None or not inspect.isclass(registered):
return None
try:
if issubclass(registered, requested_type):
return errors.SUGGESTION_SUBCLASS.format(type_name=registered.__name__, scope=provider.scope.name)
if issubclass(requested_type, registered):
return errors.SUGGESTION_BASECLASS.format(type_name=registered.__name__, scope=provider.scope.name)
except TypeError:
return None
return None


class ProvidersRegistry:
__slots__ = ("_providers",)

Expand Down Expand Up @@ -31,3 +51,35 @@ def add_providers(self, *args: AbstractProvider[typing.Any]) -> None:
continue

self.register(provider.bound_type, provider)

def build_suggestions(self, requested_type: type) -> list[str]:
requested_is_class = inspect.isclass(requested_type)
requested_name = getattr(requested_type, "__name__", str(requested_type))

hierarchy_hints: list[str] = []
name_to_provider: dict[str, AbstractProvider[typing.Any]] = {}

for provider in self._providers.values():
registered = provider.bound_type
if registered is None or registered is requested_type:
continue

hint = _hierarchy_hint(requested_type, provider) if requested_is_class else None
if hint is not None:
hierarchy_hints.append(hint)
if len(hierarchy_hints) >= _MAX_SUGGESTIONS:
return hierarchy_hints
continue

name = getattr(registered, "__name__", None)
if name:
name_to_provider[name] = provider

remaining = _MAX_SUGGESTIONS - len(hierarchy_hints)
typo_hints = [
errors.SUGGESTION_SIMILAR.format(type_name=name, scope=name_to_provider[name].scope.name)
for name in difflib.get_close_matches(
requested_name, name_to_provider.keys(), n=remaining, cutoff=_SIMILARITY_CUTOFF
)
]
return hierarchy_hints + typo_hints
172 changes: 172 additions & 0 deletions tests/test_suggestions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import dataclasses
import typing

import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import ProviderNotRegisteredError
from modern_di.registries.providers_registry import _hierarchy_hint


class Database:
pass


@dataclasses.dataclass(kw_only=True, slots=True)
class PostgresDatabase(Database):
pass


@dataclasses.dataclass(kw_only=True, slots=True)
class Repository:
pass


def test_subclass_suggestion() -> None:
class G(Group):
db = providers.Factory(creator=PostgresDatabase)

container = Container(groups=[G])
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(Database)

assert str(exc_info.value) == (
"Provider of type <class 'tests.test_suggestions.Database'> is not registered in providers registry.\n"
"Did you mean:\n"
" - PostgresDatabase (registered subclass, scope=APP)"
)


def test_baseclass_suggestion() -> None:
class G(Group):
db = providers.Factory(creator=Database)

container = Container(groups=[G])
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(PostgresDatabase)

assert str(exc_info.value) == (
"Provider of type <class 'tests.test_suggestions.PostgresDatabase'> is not registered in providers registry.\n"
"Did you mean:\n"
" - Database (registered base class, scope=APP)"
)


def test_typo_suggestion() -> None:
class G(Group):
repo = providers.Factory(creator=Repository)

@dataclasses.dataclass(kw_only=True, slots=True)
class Repostory:
pass

container = Container(groups=[G])
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(Repostory)

assert str(exc_info.value) == (
"Provider of type "
"<class 'tests.test_suggestions.test_typo_suggestion.<locals>.Repostory'> "
"is not registered in providers registry.\n"
"Did you mean:\n"
" - Repository (similar name, scope=APP)"
)


def test_suggestion_includes_provider_scope() -> None:
class G(Group):
db = providers.Factory(scope=Scope.REQUEST, creator=PostgresDatabase)

container = Container(groups=[G])
request_container = container.build_child_container(scope=Scope.REQUEST)
with pytest.raises(ProviderNotRegisteredError) as exc_info:
request_container.resolve(Database)

assert str(exc_info.value) == (
"Provider of type <class 'tests.test_suggestions.Database'> is not registered in providers registry.\n"
"Did you mean:\n"
" - PostgresDatabase (registered subclass, scope=REQUEST)"
)


def test_no_suggestions_when_nothing_matches() -> None:
container = Container()
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(int)

assert str(exc_info.value) == "Provider of type <class 'int'> is not registered in providers registry."


def test_suggestions_capped_at_three() -> None:
@dataclasses.dataclass(kw_only=True, slots=True)
class A1(Database):
pass

@dataclasses.dataclass(kw_only=True, slots=True)
class A2(Database):
pass

@dataclasses.dataclass(kw_only=True, slots=True)
class A3(Database):
pass

@dataclasses.dataclass(kw_only=True, slots=True)
class A4(Database):
pass

@dataclasses.dataclass(kw_only=True, slots=True)
class A5(Database):
pass

class G(Group):
a1 = providers.Factory(creator=A1)
a2 = providers.Factory(creator=A2)
a3 = providers.Factory(creator=A3)
a4 = providers.Factory(creator=A4)
a5 = providers.Factory(creator=A5)

container = Container(groups=[G])
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(Database)

assert str(exc_info.value) == (
"Provider of type <class 'tests.test_suggestions.Database'> is not registered in providers registry.\n"
"Did you mean:\n"
" - A1 (registered subclass, scope=APP)\n"
" - A2 (registered subclass, scope=APP)\n"
" - A3 (registered subclass, scope=APP)"
)


def test_hierarchy_hint_skips_non_class_bound_type() -> None:
provider = providers.Factory(creator=list, bound_type=list[int])
assert _hierarchy_hint(int, provider) is None


def test_hierarchy_hint_swallows_protocol_typeerror() -> None:
class MyProto(typing.Protocol):
def foo(self) -> None: ...

provider = providers.Factory(creator=lambda: 1, bound_type=int)
assert _hierarchy_hint(MyProto, provider) is None


def test_hierarchy_hint_preferred_over_typo() -> None:
@dataclasses.dataclass(kw_only=True, slots=True)
class Databse:
pass

class G(Group):
db = providers.Factory(creator=PostgresDatabase)
typo = providers.Factory(creator=Databse)

container = Container(groups=[G])
with pytest.raises(ProviderNotRegisteredError) as exc_info:
container.resolve(Database)

assert str(exc_info.value) == (
"Provider of type <class 'tests.test_suggestions.Database'> is not registered in providers registry.\n"
"Did you mean:\n"
" - PostgresDatabase (registered subclass, scope=APP)\n"
" - Databse (similar name, scope=APP)"
)
Loading