Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions docs/providers/alias.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Alias

`Alias` lets one type resolve to whatever provider already handles a different type. The most common use is binding an abstract base or `Protocol` to a concrete implementation that is already registered, without registering the implementation twice.

Resolving the alias delegates straight back through the container, so overrides and caching on the source provider apply transparently.

## Parameters

### source_type

The type whose registered provider should answer the call. At resolution time, the alias looks up `source_type` in the providers registry and delegates to that provider. If `source_type` is not registered, an `AliasSourceNotRegisteredError` is raised.

### bound_type

The type the alias is registered under in the providers registry — i.e. the type you pass to `container.resolve(...)`. Defaults to `source_type` (which makes the alias a no-op); set it to the abstract or `Protocol` type you want resolvable.

### scope

Standard scope parameter; defaults to `Scope.APP`. The alias does not enforce its own scope-based caching — the source provider's scope governs where the actual instance lives — so the practical effect of `scope` on `Alias` is limited. Setting it to match the source's scope is a reasonable convention.

## Basic Usage

```python
import dataclasses
from typing import Protocol

from modern_di import Container, Group, Scope, providers


class Repository(Protocol):
def fetch(self) -> list[str]: ...


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class PostgresRepository:
dsn: str = "postgres://localhost"

def fetch(self) -> list[str]:
return ["row-1", "row-2"]


class Dependencies(Group):
repo = providers.Factory(
creator=PostgresRepository,
cache_settings=providers.CacheSettings(),
)
abstract_repo = providers.Alias(
source_type=PostgresRepository,
bound_type=Repository,
)


container = Container(groups=[Dependencies])

concrete = container.resolve(PostgresRepository)
abstract = container.resolve(Repository)

# Both resolve to the same instance — the alias delegates to the
# cached source factory.
assert concrete is abstract
```

## Sharing the source's cache

Because `Alias` does not cache anything itself, callers automatically share whatever instance the source provider returns. With a cached `Factory`, every resolution path — by the concrete type, by the abstract type, or via a downstream factory parameter typed as the abstract — returns the same singleton.

With an uncached source `Factory`, each resolution still goes through the source factory, so each call produces a new instance (matching the source factory's own behavior).

## Overrides

Overrides are keyed by `provider_id`, so the alias and its source can be overridden independently:

```python
mock_for_alias = PostgresRepository(dsn="alias-mock")
container.override(Dependencies.abstract_repo, mock_for_alias)

assert container.resolve(Repository) is mock_for_alias
# The source provider is untouched.
assert container.resolve(PostgresRepository) is not mock_for_alias
```

Override the source provider instead, and both resolution paths see the mock:

```python
mock_for_source = PostgresRepository(dsn="source-mock")
container.override(Dependencies.repo, mock_for_source)

assert container.resolve(PostgresRepository) is mock_for_source
assert container.resolve(Repository) is mock_for_source
```

## Validation and cycle detection

`Alias` participates in `container.validate()` (and `Container(..., validate=True)`):

- If `source_type` is not registered, `AliasSourceNotRegisteredError` is raised eagerly.
- The alias reports the source provider as a dependency, so cycles that pass through an alias are detected and reported via `CircularDependencyError`.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ nav:
- Factories: providers/factories.md
- Context: providers/context.md
- Container: providers/container.md
- Alias: providers/alias.md
- Integrations:
- FastAPI: integrations/fastapi.md
- FastStream: integrations/faststream.md
Expand Down
4 changes: 4 additions & 0 deletions modern_di/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@
"2. Explicitly pass dependencies via the kwargs parameter to avoid automatic resolution\n"
"See https://modern-di.readthedocs.io/latest/troubleshooting/duplicate-type-error/ for more details"
)
ALIAS_SOURCE_NOT_REGISTERED_ERROR = (
"Alias source type {source_type} is not registered in providers registry. "
"Register a provider for {source_type} before defining the alias."
)
6 changes: 6 additions & 0 deletions modern_di/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ def __init__(
super().__init__(message)


class AliasSourceNotRegisteredError(ResolutionError):
def __init__(self, *, source_type: type) -> None:
self.source_type = source_type
super().__init__(errors.ALIAS_SOURCE_NOT_REGISTERED_ERROR.format(source_type=source_type))


class ArgumentResolutionError(ResolutionError):
def __init__(self, *, arg_name: str, arg_type: typing.Any, bound_type: typing.Any) -> None: # noqa: ANN401
self.arg_name = arg_name
Expand Down
2 changes: 2 additions & 0 deletions modern_di/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from modern_di.providers.abstract import AbstractProvider
from modern_di.providers.alias import Alias
from modern_di.providers.container_provider import container_provider
from modern_di.providers.context_provider import ContextProvider
from modern_di.providers.factory import CacheSettings, Factory


__all__ = [
"AbstractProvider",
"Alias",
"CacheSettings",
"ContextProvider",
"Factory",
Expand Down
47 changes: 47 additions & 0 deletions modern_di/providers/alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import typing

from modern_di import exceptions, types
from modern_di.providers.abstract import AbstractProvider
from modern_di.scope import Scope


if typing.TYPE_CHECKING:
from modern_di import Container


class Alias(AbstractProvider[types.T_co]):
__slots__ = [*AbstractProvider.BASE_SLOTS, "_source_type"]

def __init__(
self,
*,
source_type: type[types.T_co],
scope: Scope = Scope.APP,
bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default]
) -> None:
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else source_type)
self._source_type = source_type

def __repr__(self) -> str:
return f"Alias(source_type={self._source_type!r}, bound_type={self.bound_type!r}, scope={self.scope!r})"

def _find_source(self, container: "Container") -> "AbstractProvider[types.T_co]":
source = container.providers_registry.find_provider(self._source_type)
if source is None:
raise exceptions.AliasSourceNotRegisteredError(source_type=self._source_type)
return source

def get_dependencies(self, container: "Container") -> dict[str, "AbstractProvider[typing.Any]"]:
return {"source": self._find_source(container)}

def validate(self, container: "Container") -> dict[str, typing.Any]:
source = self._find_source(container)
return {
"bound_type": self.bound_type,
"source_type": self._source_type,
"source": source.validate(container),
"self": self,
}

def resolve(self, container: "Container") -> types.T_co:
return container.resolve_provider(self._find_source(container))
137 changes: 137 additions & 0 deletions tests/providers/test_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import dataclasses

import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import (
AliasSourceNotRegisteredError,
CircularDependencyError,
ScopeNotInitializedError,
)


class AbstractRepository: ...


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class PostgresRepository(AbstractRepository):
dsn: str = "postgres://localhost"


class MyGroup(Group):
repo = providers.Factory(creator=PostgresRepository, cache_settings=providers.CacheSettings())
abstract_repo = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository)


def test_alias_delegates_to_source() -> None:
container = Container(groups=[MyGroup], validate=True)
concrete = container.resolve(PostgresRepository)
abstract = container.resolve(AbstractRepository)
assert isinstance(abstract, PostgresRepository)
assert concrete is abstract


def test_alias_resolve_provider() -> None:
container = Container(groups=[MyGroup])
container.validate_provider(MyGroup.abstract_repo)
instance = container.resolve_provider(MyGroup.abstract_repo)
assert isinstance(instance, PostgresRepository)


def test_alias_without_caching_returns_fresh_instance_per_call() -> None:
class G(Group):
repo = providers.Factory(creator=PostgresRepository)
abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository)

container = Container(groups=[G])
a = container.resolve(AbstractRepository)
b = container.resolve(PostgresRepository)
assert isinstance(a, PostgresRepository)
assert isinstance(b, PostgresRepository)
assert a is not b


def test_alias_respects_source_scope() -> None:
class G(Group):
repo = providers.Factory(scope=Scope.REQUEST, creator=PostgresRepository)
abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository)

app_container = Container(groups=[G])
with pytest.raises(ScopeNotInitializedError):
app_container.resolve(AbstractRepository)

request_container = app_container.build_child_container(scope=Scope.REQUEST)
instance = request_container.resolve(AbstractRepository)
assert isinstance(instance, PostgresRepository)


def test_alias_override_does_not_affect_source() -> None:
container = Container(groups=[MyGroup])
mock = PostgresRepository(dsn="mock-alias")
container.override(MyGroup.abstract_repo, mock)

assert container.resolve(AbstractRepository) is mock
assert container.resolve(PostgresRepository) is not mock


def test_source_override_propagates_through_alias() -> None:
container = Container(groups=[MyGroup])
mock = PostgresRepository(dsn="mock-source")
container.override(MyGroup.repo, mock)

assert container.resolve(PostgresRepository) is mock
assert container.resolve(AbstractRepository) is mock


def test_alias_missing_source_raises_on_resolve() -> None:
class G(Group):
abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository)

container = Container(groups=[G])
with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository") as exc:
container.resolve(AbstractRepository)
assert exc.value.source_type is PostgresRepository


def test_alias_missing_source_raises_on_validate_provider() -> None:
class G(Group):
abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository)

container = Container(groups=[G])
with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository"):
container.validate_provider(G.abstract)


def test_alias_missing_source_raises_on_container_validate() -> None:
class G(Group):
abstract = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository)

with pytest.raises(AliasSourceNotRegisteredError, match="PostgresRepository"):
Container(groups=[G], validate=True)


def test_alias_participates_in_cycle_detection() -> None:
class Iface: ...

@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class Concrete(Iface):
dep: Iface

class G(Group):
concrete = providers.Factory(creator=Concrete)
iface_alias = providers.Alias(source_type=Concrete, bound_type=Iface)

with pytest.raises(CircularDependencyError, match="Concrete"):
Container(groups=[G], validate=True)


def test_alias_default_bound_type_is_source_type() -> None:
alias = providers.Alias(source_type=PostgresRepository)
assert alias.bound_type is PostgresRepository


def test_alias_repr() -> None:
alias = providers.Alias(source_type=PostgresRepository, bound_type=AbstractRepository, scope=Scope.REQUEST)
assert repr(alias) == (
f"Alias(source_type={PostgresRepository!r}, bound_type={AbstractRepository!r}, scope=<Scope.REQUEST: 3>)"
)
Loading