Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/dev/key-concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@
- For lifetime less than **ACTION**;
- Must be managed manually.

### Custom scopes

For non-standard lifecycles (per-tenant containers, background jobs, etc.) you
can pass any `IntEnum` value where `Scope` is accepted:

```python
from enum import IntEnum
from modern_di import Container, Scope, providers

class MyScope(IntEnum):
TENANT = 6
BACKGROUND_JOB = 7

provider = providers.Factory(scope=MyScope.TENANT, creator=...)
tenant_container = Container().build_child_container(scope=MyScope.TENANT)
```

A child scope's integer value must be strictly greater than its parent's. The
auto-derive of the next scope (when `scope` is omitted from
`build_child_container`) only advances within the parent's own enum class — to
cross enum boundaries, pass `scope=` explicitly.

### How to choose scope

Provider's scope must be max value between scopes of all its dependencies.
Expand Down
11 changes: 6 additions & 5 deletions modern_di/container.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import threading
import typing

Expand Down Expand Up @@ -28,7 +29,7 @@ class Container:

def __init__( # noqa: PLR0913
self,
scope: Scope = Scope.APP,
scope: enum.IntEnum = Scope.APP,
parent_container: typing.Optional["typing_extensions.Self"] = None,
context: dict[type[typing.Any], typing.Any] | None = None,
groups: list[type[Group]] | None = None,
Expand All @@ -38,7 +39,7 @@ def __init__( # noqa: PLR0913
self.lock = threading.Lock() if use_lock else None
self.scope = scope
self.parent_container = parent_container
self.scope_map: dict[Scope, typing_extensions.Self] = (
self.scope_map: dict[enum.IntEnum, typing_extensions.Self] = (
{**parent_container.scope_map, scope: self} if parent_container else {scope: self}
)
self.cache_registry = CacheRegistry()
Expand All @@ -59,13 +60,13 @@ def __init__( # noqa: PLR0913
self.validate()

def build_child_container(
self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None
self, context: dict[type[typing.Any], typing.Any] | None = None, scope: enum.IntEnum | None = None
) -> "typing_extensions.Self":
if scope and scope <= self.scope:
raise exceptions.InvalidChildScopeError(
parent_scope=self.scope,
child_scope=scope,
allowed_scopes=[x.name for x in Scope if x > self.scope],
allowed_scopes=[x.name for x in type(self.scope) if x > self.scope],
)

if not scope:
Expand All @@ -76,7 +77,7 @@ def build_child_container(

return self.__class__(scope=scope, parent_container=self, context=context)

def find_container(self, scope: Scope) -> "typing_extensions.Self":
def find_container(self, scope: enum.IntEnum) -> "typing_extensions.Self":
if scope not in self.scope_map:
if scope > self.scope:
raise exceptions.ScopeNotInitializedError(provider_scope=scope, container_scope=self.scope)
Expand Down
12 changes: 6 additions & 6 deletions modern_di/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import dataclasses
import enum
import typing

from modern_di import errors
from modern_di.scope import Scope


@dataclasses.dataclass(frozen=True, slots=True)
class ResolutionStep:
scope: Scope
scope: enum.IntEnum
name: str


Expand All @@ -20,7 +20,7 @@ class ContainerError(ModernDIError):


class InvalidChildScopeError(ContainerError):
def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: list[str]) -> None:
def __init__(self, *, parent_scope: enum.IntEnum, child_scope: enum.IntEnum, allowed_scopes: list[str]) -> None:
self.parent_scope = parent_scope
self.child_scope = child_scope
self.allowed_scopes = allowed_scopes
Expand All @@ -34,13 +34,13 @@ def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: l


class MaxScopeReachedError(ContainerError):
def __init__(self, *, parent_scope: Scope) -> None:
def __init__(self, *, parent_scope: enum.IntEnum) -> None:
self.parent_scope = parent_scope
super().__init__(errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=parent_scope.name))


class ScopeNotInitializedError(ContainerError):
def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None:
def __init__(self, *, provider_scope: enum.IntEnum, container_scope: enum.IntEnum) -> None:
self.provider_scope = provider_scope
self.container_scope = container_scope
super().__init__(
Expand All @@ -52,7 +52,7 @@ def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None:


class ScopeSkippedError(ContainerError):
def __init__(self, *, provider_scope: Scope) -> None:
def __init__(self, *, provider_scope: enum.IntEnum) -> None:
self.provider_scope = provider_scope
super().__init__(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=provider_scope.name))

Expand Down
4 changes: 2 additions & 2 deletions modern_di/providers/abstract.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import abc
import enum
import itertools
import typing

from modern_di import types
from modern_di.scope import Scope


if typing.TYPE_CHECKING:
Expand All @@ -18,7 +18,7 @@ class AbstractProvider(abc.ABC, typing.Generic[types.T_co]):
def __init__(
self,
*,
scope: Scope,
scope: enum.IntEnum,
bound_type: type | None,
) -> None:
self.scope = scope
Expand Down
3 changes: 2 additions & 1 deletion modern_di/providers/alias.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import typing

from modern_di import exceptions, types
Expand All @@ -16,7 +17,7 @@ def __init__(
self,
*,
source_type: type[types.T_co],
scope: Scope = Scope.APP,
scope: enum.IntEnum = Scope.APP,
bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default]
) -> None:
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else source_type)
Expand Down
3 changes: 2 additions & 1 deletion modern_di/providers/context_provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import typing

from modern_di import types
Expand All @@ -15,7 +16,7 @@ class ContextProvider(AbstractProvider[types.T_co]):
def __init__(
self,
*,
scope: Scope = Scope.APP,
scope: enum.IntEnum = Scope.APP,
context_type: type[types.T_co],
bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default]
) -> None:
Expand Down
3 changes: 2 additions & 1 deletion modern_di/providers/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import enum
import inspect
import typing
import warnings
Expand Down Expand Up @@ -31,7 +32,7 @@ class Factory(AbstractProvider[types.T_co]):
def __init__( # noqa: PLR0913
self,
*,
scope: Scope = Scope.APP,
scope: enum.IntEnum = Scope.APP,
creator: typing.Callable[..., types.T_co],
bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default]
kwargs: dict[str, typing.Any] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/test_container_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_container_provider_direct_resolving() -> None:

def test_container_provider_sub_dependency() -> None:
def creator(di_container: Container) -> Scope:
return di_container.scope
return Scope(di_container.scope)

class MyGroup(Group):
factory = providers.Factory(scope=Scope.REQUEST, creator=creator)
Expand Down
110 changes: 110 additions & 0 deletions tests/test_custom_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import dataclasses
import enum

import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import (
InvalidChildScopeError,
ScopeNotInitializedError,
ScopeSkippedError,
)


class MyScope(enum.IntEnum):
TENANT = 6
BACKGROUND_JOB = 7


class ConflictingScope(enum.IntEnum):
SAME_AS_APP = 1
LOWER_THAN_REQUEST = 2


@dataclasses.dataclass(kw_only=True, slots=True)
class TenantService:
pass


def test_build_child_at_custom_scope_from_step() -> None:
step_container = Container(scope=Scope.STEP)
tenant_container = step_container.build_child_container(scope=MyScope.TENANT)
assert tenant_container.scope is MyScope.TENANT
assert tenant_container.parent_container is step_container


def test_build_child_at_custom_scope_from_app_skips_intermediate() -> None:
app_container = Container()
tenant_container = app_container.build_child_container(scope=MyScope.TENANT)
assert tenant_container.scope is MyScope.TENANT


def test_factory_resolves_through_custom_scope_container() -> None:
class TenantGroup(Group):
svc = providers.Factory(scope=MyScope.TENANT, creator=TenantService)

app_container = Container(groups=[TenantGroup])
tenant_container = app_container.build_child_container(scope=MyScope.TENANT)

instance = tenant_container.resolve(TenantService)
assert isinstance(instance, TenantService)


def test_resolve_at_custom_scope_from_app_raises_scope_not_initialized() -> None:
class TenantGroup(Group):
svc = providers.Factory(scope=MyScope.TENANT, creator=TenantService)

app_container = Container(groups=[TenantGroup])
with pytest.raises(ScopeNotInitializedError, match="TENANT") as exc:
app_container.resolve(TenantService)
assert exc.value.provider_scope is MyScope.TENANT
assert exc.value.container_scope is Scope.APP


def test_resolve_app_provider_from_custom_scope_with_skipped_chain() -> None:
# A standalone tenant container that never went through APP -> ... chain
tenant_container = Container(scope=MyScope.TENANT)
app_factory = providers.Factory(creator=lambda: "x")
with pytest.raises(ScopeSkippedError, match="APP"):
tenant_container.resolve_provider(app_factory)


def test_invalid_child_scope_uses_parent_enum_for_allowed_list() -> None:
tenant_container = Container(scope=MyScope.TENANT)
with pytest.raises(InvalidChildScopeError) as exc:
tenant_container.build_child_container(scope=MyScope.TENANT)
# allowed_scopes must be drawn from the parent's own enum class (MyScope),
# not the standard Scope enum.
assert exc.value.allowed_scopes == ["BACKGROUND_JOB"]


def test_invalid_child_scope_with_conflicting_value() -> None:
app_container = Container()
with pytest.raises(InvalidChildScopeError) as exc:
app_container.build_child_container(scope=ConflictingScope.SAME_AS_APP)
assert exc.value.parent_scope is Scope.APP
assert exc.value.child_scope is ConflictingScope.SAME_AS_APP


def test_caching_isolated_across_tenant_containers() -> None:
class TenantGroup(Group):
svc = providers.Factory(
scope=MyScope.TENANT,
creator=TenantService,
cache_settings=providers.CacheSettings(),
)

app_container = Container(groups=[TenantGroup])
tenant_a = app_container.build_child_container(scope=MyScope.TENANT)
tenant_b = app_container.build_child_container(scope=MyScope.TENANT)

instance_a = tenant_a.resolve(TenantService)
instance_b = tenant_b.resolve(TenantService)
assert instance_a is not instance_b
assert tenant_a.resolve(TenantService) is instance_a


def test_auto_derive_within_custom_enum() -> None:
tenant_container = Container(scope=MyScope.TENANT)
bg_container = tenant_container.build_child_container()
assert bg_container.scope is MyScope.BACKGROUND_JOB
Loading