diff --git a/modern_di/container.py b/modern_di/container.py index 8ecf828..50512bf 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -86,7 +86,10 @@ def find_container(self, scope: Scope) -> "typing_extensions.Self": def resolve(self, dependency_type: type[types.T]) -> types.T: provider = self.providers_registry.find_provider(dependency_type) if not provider: - raise exceptions.ProviderNotRegisteredError(provider_type=dependency_type) + raise exceptions.ProviderNotRegisteredError( + provider_type=dependency_type, + suggestions=self.providers_registry.build_suggestions(dependency_type), + ) return self.resolve_provider(provider) diff --git a/modern_di/errors.py b/modern_di/errors.py index 17c2d25..3a2ab6c 100644 --- a/modern_di/errors.py +++ b/modern_di/errors.py @@ -8,6 +8,10 @@ ) CONTAINER_SCOPE_IS_SKIPPED_ERROR = "Provider of scope {provider_scope} is skipped in the chain of containers." CONTAINER_MISSING_PROVIDER_ERROR = "Provider of type {provider_type} is not registered in providers registry." +SUGGESTION_HEADER = "Did you mean:" +SUGGESTION_SUBCLASS = " - {type_name} (registered subclass, scope={scope})" +SUGGESTION_BASECLASS = " - {type_name} (registered base class, scope={scope})" +SUGGESTION_SIMILAR = " - {type_name} (similar name, scope={scope})" FACTORY_ARGUMENT_RESOLUTION_ERROR = ( "Argument {arg_name} of type {arg_type} cannot be resolved. Trying to build dependency {bound_type}." ) diff --git a/modern_di/exceptions.py b/modern_di/exceptions.py index 6478c8f..f3b895c 100644 --- a/modern_di/exceptions.py +++ b/modern_di/exceptions.py @@ -88,9 +88,18 @@ def __str__(self) -> str: class ProviderNotRegisteredError(ResolutionError): - def __init__(self, *, provider_type: type) -> None: + def __init__( + self, + *, + provider_type: type, + suggestions: list[str] | None = None, + ) -> None: self.provider_type = provider_type - super().__init__(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type)) + self.suggestions = suggestions or [] + message = errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type) + if self.suggestions: + message += "\n" + errors.SUGGESTION_HEADER + "\n" + "\n".join(self.suggestions) + super().__init__(message) class ArgumentResolutionError(ResolutionError): diff --git a/modern_di/registries/providers_registry.py b/modern_di/registries/providers_registry.py index 9ace295..3b1d1b7 100644 --- a/modern_di/registries/providers_registry.py +++ b/modern_di/registries/providers_registry.py @@ -1,9 +1,29 @@ +import difflib +import inspect import typing -from modern_di import exceptions, types +from modern_di import errors, exceptions, types from modern_di.providers.abstract import AbstractProvider +_MAX_SUGGESTIONS = 3 +_SIMILARITY_CUTOFF = 0.6 + + +def _hierarchy_hint(requested_type: type, provider: AbstractProvider[typing.Any]) -> str | None: + registered = provider.bound_type + if registered is None or not inspect.isclass(registered): + return None + try: + if issubclass(registered, requested_type): + return errors.SUGGESTION_SUBCLASS.format(type_name=registered.__name__, scope=provider.scope.name) + if issubclass(requested_type, registered): + return errors.SUGGESTION_BASECLASS.format(type_name=registered.__name__, scope=provider.scope.name) + except TypeError: + return None + return None + + class ProvidersRegistry: __slots__ = ("_providers",) @@ -31,3 +51,35 @@ def add_providers(self, *args: AbstractProvider[typing.Any]) -> None: continue self.register(provider.bound_type, provider) + + def build_suggestions(self, requested_type: type) -> list[str]: + requested_is_class = inspect.isclass(requested_type) + requested_name = getattr(requested_type, "__name__", str(requested_type)) + + hierarchy_hints: list[str] = [] + name_to_provider: dict[str, AbstractProvider[typing.Any]] = {} + + for provider in self._providers.values(): + registered = provider.bound_type + if registered is None or registered is requested_type: + continue + + hint = _hierarchy_hint(requested_type, provider) if requested_is_class else None + if hint is not None: + hierarchy_hints.append(hint) + if len(hierarchy_hints) >= _MAX_SUGGESTIONS: + return hierarchy_hints + continue + + name = getattr(registered, "__name__", None) + if name: + name_to_provider[name] = provider + + remaining = _MAX_SUGGESTIONS - len(hierarchy_hints) + typo_hints = [ + errors.SUGGESTION_SIMILAR.format(type_name=name, scope=name_to_provider[name].scope.name) + for name in difflib.get_close_matches( + requested_name, name_to_provider.keys(), n=remaining, cutoff=_SIMILARITY_CUTOFF + ) + ] + return hierarchy_hints + typo_hints diff --git a/tests/test_suggestions.py b/tests/test_suggestions.py new file mode 100644 index 0000000..3c3ff16 --- /dev/null +++ b/tests/test_suggestions.py @@ -0,0 +1,172 @@ +import dataclasses +import typing + +import pytest + +from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ProviderNotRegisteredError +from modern_di.registries.providers_registry import _hierarchy_hint + + +class Database: + pass + + +@dataclasses.dataclass(kw_only=True, slots=True) +class PostgresDatabase(Database): + pass + + +@dataclasses.dataclass(kw_only=True, slots=True) +class Repository: + pass + + +def test_subclass_suggestion() -> None: + class G(Group): + db = providers.Factory(creator=PostgresDatabase) + + container = Container(groups=[G]) + with pytest.raises(ProviderNotRegisteredError) as exc_info: + container.resolve(Database) + + assert str(exc_info.value) == ( + "Provider of type is not registered in providers registry.\n" + "Did you mean:\n" + " - PostgresDatabase (registered subclass, scope=APP)" + ) + + +def test_baseclass_suggestion() -> None: + class G(Group): + db = providers.Factory(creator=Database) + + container = Container(groups=[G]) + with pytest.raises(ProviderNotRegisteredError) as exc_info: + container.resolve(PostgresDatabase) + + assert str(exc_info.value) == ( + "Provider of type is not registered in providers registry.\n" + "Did you mean:\n" + " - Database (registered base class, scope=APP)" + ) + + +def test_typo_suggestion() -> None: + class G(Group): + repo = providers.Factory(creator=Repository) + + @dataclasses.dataclass(kw_only=True, slots=True) + class Repostory: + pass + + container = Container(groups=[G]) + with pytest.raises(ProviderNotRegisteredError) as exc_info: + container.resolve(Repostory) + + assert str(exc_info.value) == ( + "Provider of type " + ".Repostory'> " + "is not registered in providers registry.\n" + "Did you mean:\n" + " - Repository (similar name, scope=APP)" + ) + + +def test_suggestion_includes_provider_scope() -> None: + class G(Group): + db = providers.Factory(scope=Scope.REQUEST, creator=PostgresDatabase) + + container = Container(groups=[G]) + request_container = container.build_child_container(scope=Scope.REQUEST) + with pytest.raises(ProviderNotRegisteredError) as exc_info: + request_container.resolve(Database) + + assert str(exc_info.value) == ( + "Provider of type is not registered in providers registry.\n" + "Did you mean:\n" + " - PostgresDatabase (registered subclass, scope=REQUEST)" + ) + + +def test_no_suggestions_when_nothing_matches() -> None: + container = Container() + with pytest.raises(ProviderNotRegisteredError) as exc_info: + container.resolve(int) + + assert str(exc_info.value) == "Provider of type is not registered in providers registry." + + +def test_suggestions_capped_at_three() -> None: + @dataclasses.dataclass(kw_only=True, slots=True) + class A1(Database): + pass + + @dataclasses.dataclass(kw_only=True, slots=True) + class A2(Database): + pass + + @dataclasses.dataclass(kw_only=True, slots=True) + class A3(Database): + pass + + @dataclasses.dataclass(kw_only=True, slots=True) + class A4(Database): + pass + + @dataclasses.dataclass(kw_only=True, slots=True) + class A5(Database): + pass + + class G(Group): + a1 = providers.Factory(creator=A1) + a2 = providers.Factory(creator=A2) + a3 = providers.Factory(creator=A3) + a4 = providers.Factory(creator=A4) + a5 = providers.Factory(creator=A5) + + container = Container(groups=[G]) + with pytest.raises(ProviderNotRegisteredError) as exc_info: + container.resolve(Database) + + assert str(exc_info.value) == ( + "Provider of type is not registered in providers registry.\n" + "Did you mean:\n" + " - A1 (registered subclass, scope=APP)\n" + " - A2 (registered subclass, scope=APP)\n" + " - A3 (registered subclass, scope=APP)" + ) + + +def test_hierarchy_hint_skips_non_class_bound_type() -> None: + provider = providers.Factory(creator=list, bound_type=list[int]) + assert _hierarchy_hint(int, provider) is None + + +def test_hierarchy_hint_swallows_protocol_typeerror() -> None: + class MyProto(typing.Protocol): + def foo(self) -> None: ... + + provider = providers.Factory(creator=lambda: 1, bound_type=int) + assert _hierarchy_hint(MyProto, provider) is None + + +def test_hierarchy_hint_preferred_over_typo() -> None: + @dataclasses.dataclass(kw_only=True, slots=True) + class Databse: + pass + + class G(Group): + db = providers.Factory(creator=PostgresDatabase) + typo = providers.Factory(creator=Databse) + + container = Container(groups=[G]) + with pytest.raises(ProviderNotRegisteredError) as exc_info: + container.resolve(Database) + + assert str(exc_info.value) == ( + "Provider of type is not registered in providers registry.\n" + "Did you mean:\n" + " - PostgresDatabase (registered subclass, scope=APP)\n" + " - Databse (similar name, scope=APP)" + )