diff --git a/docs/providers/alias.md b/docs/providers/alias.md new file mode 100644 index 0000000..912e206 --- /dev/null +++ b/docs/providers/alias.md @@ -0,0 +1,97 @@ +# Alias + +`Alias` lets one type resolve to whatever provider already handles a different type. The most common use is binding an abstract base or `Protocol` to a concrete implementation that is already registered, without registering the implementation twice. + +Resolving the alias delegates straight back through the container, so overrides and caching on the source provider apply transparently. + +## Parameters + +### source_type + +The type whose registered provider should answer the call. At resolution time, the alias looks up `source_type` in the providers registry and delegates to that provider. If `source_type` is not registered, an `AliasSourceNotRegisteredError` is raised. + +### bound_type + +The type the alias is registered under in the providers registry — i.e. the type you pass to `container.resolve(...)`. Defaults to `source_type` (which makes the alias a no-op); set it to the abstract or `Protocol` type you want resolvable. + +### scope + +Standard scope parameter; defaults to `Scope.APP`. The alias does not enforce its own scope-based caching — the source provider's scope governs where the actual instance lives — so the practical effect of `scope` on `Alias` is limited. Setting it to match the source's scope is a reasonable convention. + +## Basic Usage + +```python +import dataclasses +from typing import Protocol + +from modern_di import Container, Group, Scope, providers + + +class Repository(Protocol): + def fetch(self) -> list[str]: ... + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class PostgresRepository: + dsn: str = "postgres://localhost" + + def fetch(self) -> list[str]: + return ["row-1", "row-2"] + + +class Dependencies(Group): + repo = providers.Factory( + creator=PostgresRepository, + cache_settings=providers.CacheSettings(), + ) + abstract_repo = providers.Alias( + source_type=PostgresRepository, + bound_type=Repository, + ) + + +container = Container(groups=[Dependencies]) + +concrete = container.resolve(PostgresRepository) +abstract = container.resolve(Repository) + +# Both resolve to the same instance — the alias delegates to the +# cached source factory. +assert concrete is abstract +``` + +## Sharing the source's cache + +Because `Alias` does not cache anything itself, callers automatically share whatever instance the source provider returns. With a cached `Factory`, every resolution path — by the concrete type, by the abstract type, or via a downstream factory parameter typed as the abstract — returns the same singleton. + +With an uncached source `Factory`, each resolution still goes through the source factory, so each call produces a new instance (matching the source factory's own behavior). + +## Overrides + +Overrides are keyed by `provider_id`, so the alias and its source can be overridden independently: + +```python +mock_for_alias = PostgresRepository(dsn="alias-mock") +container.override(Dependencies.abstract_repo, mock_for_alias) + +assert container.resolve(Repository) is mock_for_alias +# The source provider is untouched. +assert container.resolve(PostgresRepository) is not mock_for_alias +``` + +Override the source provider instead, and both resolution paths see the mock: + +```python +mock_for_source = PostgresRepository(dsn="source-mock") +container.override(Dependencies.repo, mock_for_source) + +assert container.resolve(PostgresRepository) is mock_for_source +assert container.resolve(Repository) is mock_for_source +``` + +## Validation and cycle detection + +`Alias` participates in `container.validate()` (and `Container(..., validate=True)`): + +- If `source_type` is not registered, `AliasSourceNotRegisteredError` is raised eagerly. +- The alias reports the source provider as a dependency, so cycles that pass through an alias are detected and reported via `CircularDependencyError`. diff --git a/mkdocs.yml b/mkdocs.yml index 1ef35b1..863c058 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,6 +11,7 @@ nav: - Factories: providers/factories.md - Context: providers/context.md - Container: providers/container.md + - Alias: providers/alias.md - Integrations: - FastAPI: integrations/fastapi.md - FastStream: integrations/faststream.md diff --git a/modern_di/errors.py b/modern_di/errors.py index 3a2ab6c..ace3d08 100644 --- a/modern_di/errors.py +++ b/modern_di/errors.py @@ -23,3 +23,7 @@ "2. Explicitly pass dependencies via the kwargs parameter to avoid automatic resolution\n" "See https://modern-di.readthedocs.io/latest/troubleshooting/duplicate-type-error/ for more details" ) +ALIAS_SOURCE_NOT_REGISTERED_ERROR = ( + "Alias source type {source_type} is not registered in providers registry. " + "Register a provider for {source_type} before defining the alias." +) diff --git a/modern_di/exceptions.py b/modern_di/exceptions.py index f3b895c..778ee66 100644 --- a/modern_di/exceptions.py +++ b/modern_di/exceptions.py @@ -102,6 +102,12 @@ def __init__( super().__init__(message) +class AliasSourceNotRegisteredError(ResolutionError): + def __init__(self, *, source_type: type) -> None: + self.source_type = source_type + super().__init__(errors.ALIAS_SOURCE_NOT_REGISTERED_ERROR.format(source_type=source_type)) + + class ArgumentResolutionError(ResolutionError): def __init__(self, *, arg_name: str, arg_type: typing.Any, bound_type: typing.Any) -> None: # noqa: ANN401 self.arg_name = arg_name diff --git a/modern_di/providers/__init__.py b/modern_di/providers/__init__.py index 2d99d0d..f9b5753 100644 --- a/modern_di/providers/__init__.py +++ b/modern_di/providers/__init__.py @@ -1,4 +1,5 @@ from modern_di.providers.abstract import AbstractProvider +from modern_di.providers.alias import Alias from modern_di.providers.container_provider import container_provider from modern_di.providers.context_provider import ContextProvider from modern_di.providers.factory import CacheSettings, Factory @@ -6,6 +7,7 @@ __all__ = [ "AbstractProvider", + "Alias", "CacheSettings", "ContextProvider", "Factory", diff --git a/modern_di/providers/alias.py b/modern_di/providers/alias.py new file mode 100644 index 0000000..ed4b61c --- /dev/null +++ b/modern_di/providers/alias.py @@ -0,0 +1,47 @@ +import typing + +from modern_di import exceptions, types +from modern_di.providers.abstract import AbstractProvider +from modern_di.scope import Scope + + +if typing.TYPE_CHECKING: + from modern_di import Container + + +class Alias(AbstractProvider[types.T_co]): + __slots__ = [*AbstractProvider.BASE_SLOTS, "_source_type"] + + def __init__( + self, + *, + source_type: type[types.T_co], + scope: Scope = Scope.APP, + bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default] + ) -> None: + super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else source_type) + self._source_type = source_type + + def __repr__(self) -> str: + return f"Alias(source_type={self._source_type!r}, bound_type={self.bound_type!r}, scope={self.scope!r})" + + def _find_source(self, container: "Container") -> "AbstractProvider[types.T_co]": + source = container.providers_registry.find_provider(self._source_type) + if source is None: + raise exceptions.AliasSourceNotRegisteredError(source_type=self._source_type) + return source + + def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvider[typing.Any]"]: + return {"source": self._find_source(container)} + + def validate(self, container: "Container") -> dict[str, typing.Any]: + source = self._find_source(container) + return { + "bound_type": self.bound_type, + "source_type": self._source_type, + "source": source.validate(container), + "self": self, + } + + def resolve(self, container: "Container") -> types.T_co: + return container.resolve_provider(self._find_source(container)) diff --git a/tests/providers/test_alias.py b/tests/providers/test_alias.py new file mode 100644 index 0000000..58d46c5 --- /dev/null +++ b/tests/providers/test_alias.py @@ -0,0 +1,137 @@ +import dataclasses + +import pytest + +from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ( + AliasSourceNotRegisteredError, + CircularDependencyError, + ScopeNotInitializedError, +) + + +class AbstractRepository: ... + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class PostgresRepository(AbstractRepository): + dsn: str = "postgres://localhost" + + +class MyGroup(Group): + repo = providers.Factory(creator=PostgresRepository, cache_settings=providers.CacheSettings()) + abstract_repo = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository) + + +def test_alias_delegates_to_source() -> None: + container = Container(groups=[MyGroup], validate=True) + concrete = container.resolve(PostgresRepository) + abstract = container.resolve(AbstractRepository) + assert isinstance(abstract, PostgresRepository) + assert concrete is abstract + + +def test_alias_resolve_provider() -> None: + container = Container(groups=[MyGroup]) + container.validate_provider(MyGroup.abstract_repo) + instance = container.resolve_provider(MyGroup.abstract_repo) + assert isinstance(instance, PostgresRepository) + + +def test_alias_without_caching_returns_fresh_instance_per_call() -> None: + class G(Group): + repo = providers.Factory(creator=PostgresRepository) + abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository) + + container = Container(groups=[G]) + a = container.resolve(AbstractRepository) + b = container.resolve(PostgresRepository) + assert isinstance(a, PostgresRepository) + assert isinstance(b, PostgresRepository) + assert a is not b + + +def test_alias_respects_source_scope() -> None: + class G(Group): + repo = providers.Factory(scope=Scope.REQUEST, creator=PostgresRepository) + abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository) + + app_container = Container(groups=[G]) + with pytest.raises(ScopeNotInitializedError): + app_container.resolve(AbstractRepository) + + request_container = app_container.build_child_container(scope=Scope.REQUEST) + instance = request_container.resolve(AbstractRepository) + assert isinstance(instance, PostgresRepository) + + +def test_alias_override_does_not_affect_source() -> None: + container = Container(groups=[MyGroup]) + mock = PostgresRepository(dsn="mock-alias") + container.override(MyGroup.abstract_repo, mock) + + assert container.resolve(AbstractRepository) is mock + assert container.resolve(PostgresRepository) is not mock + + +def test_source_override_propagates_through_alias() -> None: + container = Container(groups=[MyGroup]) + mock = PostgresRepository(dsn="mock-source") + container.override(MyGroup.repo, mock) + + assert container.resolve(PostgresRepository) is mock + assert container.resolve(AbstractRepository) is mock + + +def test_alias_missing_source_raises_on_resolve() -> None: + class G(Group): + abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository) + + container = Container(groups=[G]) + with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository") as exc: + container.resolve(AbstractRepository) + assert exc.value.source_type is PostgresRepository + + +def test_alias_missing_source_raises_on_validate_provider() -> None: + class G(Group): + abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository) + + container = Container(groups=[G]) + with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository"): + container.validate_provider(G.abstract) + + +def test_alias_missing_source_raises_on_container_validate() -> None: + class G(Group): + abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository) + + with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository"): + Container(groups=[G], validate=True) + + +def test_alias_participates_in_cycle_detection() -> None: + class Iface: ... + + @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) + class Concrete(Iface): + dep: Iface + + class G(Group): + concrete = providers.Factory(creator=Concrete) + iface_alias = providers.Alias(source_type=Concrete, bound_type=Iface) + + with pytest.raises(CircularDependencyError, match="Concrete"): + Container(groups=[G], validate=True) + + +def test_alias_default_bound_type_is_source_type() -> None: + alias = providers.Alias(source_type=PostgresRepository) + assert alias.bound_type is PostgresRepository + + +def test_alias_repr() -> None: + alias = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository, scope=Scope.REQUEST) + assert repr(alias) == ( + f"Alias(source_type={PostgresRepository!r}, bound_type={AbstractRepository!r}, scope=)" + )