diff --git a/docs/dev/key-concepts.md b/docs/dev/key-concepts.md index 68c9bf1..fdae4d3 100644 --- a/docs/dev/key-concepts.md +++ b/docs/dev/key-concepts.md @@ -36,6 +36,28 @@ - For lifetime less than **ACTION**; - Must be managed manually. +### Custom scopes + +For non-standard lifecycles (per-tenant containers, background jobs, etc.) you +can pass any `IntEnum` value where `Scope` is accepted: + +```python +from enum import IntEnum +from modern_di import Container, Scope, providers + +class MyScope(IntEnum): + TENANT = 6 + BACKGROUND_JOB = 7 + +provider = providers.Factory(scope=MyScope.TENANT, creator=...) +tenant_container = Container().build_child_container(scope=MyScope.TENANT) +``` + +A child scope's integer value must be strictly greater than its parent's. The +auto-derive of the next scope (when `scope` is omitted from +`build_child_container`) only advances within the parent's own enum class — to +cross enum boundaries, pass `scope=` explicitly. + ### How to choose scope Provider's scope must be max value between scopes of all its dependencies. diff --git a/modern_di/container.py b/modern_di/container.py index 010aae3..cfb60e0 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -1,3 +1,4 @@ +import enum import threading import typing @@ -28,7 +29,7 @@ class Container: def __init__( # noqa: PLR0913 self, - scope: Scope = Scope.APP, + scope: enum.IntEnum = Scope.APP, parent_container: typing.Optional["typing_extensions.Self"] = None, context: dict[type[typing.Any], typing.Any] | None = None, groups: list[type[Group]] | None = None, @@ -38,7 +39,7 @@ def __init__( # noqa: PLR0913 self.lock = threading.Lock() if use_lock else None self.scope = scope self.parent_container = parent_container - self.scope_map: dict[Scope, typing_extensions.Self] = ( + self.scope_map: dict[enum.IntEnum, typing_extensions.Self] = ( {**parent_container.scope_map, scope: self} if parent_container else {scope: self} ) self.cache_registry = CacheRegistry() @@ -59,13 +60,13 @@ def __init__( # noqa: PLR0913 self.validate() def build_child_container( - self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None + self, context: dict[type[typing.Any], typing.Any] | None = None, scope: enum.IntEnum | None = None ) -> "typing_extensions.Self": if scope and scope <= self.scope: raise exceptions.InvalidChildScopeError( parent_scope=self.scope, child_scope=scope, - allowed_scopes=[x.name for x in Scope if x > self.scope], + allowed_scopes=[x.name for x in type(self.scope) if x > self.scope], ) if not scope: @@ -76,7 +77,7 @@ def build_child_container( return self.__class__(scope=scope, parent_container=self, context=context) - def find_container(self, scope: Scope) -> "typing_extensions.Self": + def find_container(self, scope: enum.IntEnum) -> "typing_extensions.Self": if scope not in self.scope_map: if scope > self.scope: raise exceptions.ScopeNotInitializedError(provider_scope=scope, container_scope=self.scope) diff --git a/modern_di/exceptions.py b/modern_di/exceptions.py index 778ee66..5ae3319 100644 --- a/modern_di/exceptions.py +++ b/modern_di/exceptions.py @@ -1,13 +1,13 @@ import dataclasses +import enum import typing from modern_di import errors -from modern_di.scope import Scope @dataclasses.dataclass(frozen=True, slots=True) class ResolutionStep: - scope: Scope + scope: enum.IntEnum name: str @@ -20,7 +20,7 @@ class ContainerError(ModernDIError): class InvalidChildScopeError(ContainerError): - def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: list[str]) -> None: + def __init__(self, *, parent_scope: enum.IntEnum, child_scope: enum.IntEnum, allowed_scopes: list[str]) -> None: self.parent_scope = parent_scope self.child_scope = child_scope self.allowed_scopes = allowed_scopes @@ -34,13 +34,13 @@ def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: l class MaxScopeReachedError(ContainerError): - def __init__(self, *, parent_scope: Scope) -> None: + def __init__(self, *, parent_scope: enum.IntEnum) -> None: self.parent_scope = parent_scope super().__init__(errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=parent_scope.name)) class ScopeNotInitializedError(ContainerError): - def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None: + def __init__(self, *, provider_scope: enum.IntEnum, container_scope: enum.IntEnum) -> None: self.provider_scope = provider_scope self.container_scope = container_scope super().__init__( @@ -52,7 +52,7 @@ def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None: class ScopeSkippedError(ContainerError): - def __init__(self, *, provider_scope: Scope) -> None: + def __init__(self, *, provider_scope: enum.IntEnum) -> None: self.provider_scope = provider_scope super().__init__(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=provider_scope.name)) diff --git a/modern_di/providers/abstract.py b/modern_di/providers/abstract.py index 2e7265b..d48e78f 100644 --- a/modern_di/providers/abstract.py +++ b/modern_di/providers/abstract.py @@ -1,9 +1,9 @@ import abc +import enum import itertools import typing from modern_di import types -from modern_di.scope import Scope if typing.TYPE_CHECKING: @@ -18,7 +18,7 @@ class AbstractProvider(abc.ABC, typing.Generic[types.T_co]): def __init__( self, *, - scope: Scope, + scope: enum.IntEnum, bound_type: type | None, ) -> None: self.scope = scope diff --git a/modern_di/providers/alias.py b/modern_di/providers/alias.py index d53d23a..2999ef2 100644 --- a/modern_di/providers/alias.py +++ b/modern_di/providers/alias.py @@ -1,3 +1,4 @@ +import enum import typing from modern_di import exceptions, types @@ -16,7 +17,7 @@ def __init__( self, *, source_type: type[types.T_co], - scope: Scope = Scope.APP, + scope: enum.IntEnum = Scope.APP, bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default] ) -> None: super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else source_type) diff --git a/modern_di/providers/context_provider.py b/modern_di/providers/context_provider.py index 4449a8a..662aff7 100644 --- a/modern_di/providers/context_provider.py +++ b/modern_di/providers/context_provider.py @@ -1,3 +1,4 @@ +import enum import typing from modern_di import types @@ -15,7 +16,7 @@ class ContextProvider(AbstractProvider[types.T_co]): def __init__( self, *, - scope: Scope = Scope.APP, + scope: enum.IntEnum = Scope.APP, context_type: type[types.T_co], bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default] ) -> None: diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index 419fa94..b533881 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -1,4 +1,5 @@ import dataclasses +import enum import inspect import typing import warnings @@ -31,7 +32,7 @@ class Factory(AbstractProvider[types.T_co]): def __init__( # noqa: PLR0913 self, *, - scope: Scope = Scope.APP, + scope: enum.IntEnum = Scope.APP, creator: typing.Callable[..., types.T_co], bound_type: type | None = types.UNSET, # ty: ignore[invalid-parameter-default] kwargs: dict[str, typing.Any] | None = None, diff --git a/tests/providers/test_container_provider.py b/tests/providers/test_container_provider.py index 27169d7..094fd69 100644 --- a/tests/providers/test_container_provider.py +++ b/tests/providers/test_container_provider.py @@ -11,7 +11,7 @@ def test_container_provider_direct_resolving() -> None: def test_container_provider_sub_dependency() -> None: def creator(di_container: Container) -> Scope: - return di_container.scope + return Scope(di_container.scope) class MyGroup(Group): factory = providers.Factory(scope=Scope.REQUEST, creator=creator) diff --git a/tests/test_custom_scope.py b/tests/test_custom_scope.py new file mode 100644 index 0000000..fed76a7 --- /dev/null +++ b/tests/test_custom_scope.py @@ -0,0 +1,110 @@ +import dataclasses +import enum + +import pytest + +from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ( + InvalidChildScopeError, + ScopeNotInitializedError, + ScopeSkippedError, +) + + +class MyScope(enum.IntEnum): + TENANT = 6 + BACKGROUND_JOB = 7 + + +class ConflictingScope(enum.IntEnum): + SAME_AS_APP = 1 + LOWER_THAN_REQUEST = 2 + + +@dataclasses.dataclass(kw_only=True, slots=True) +class TenantService: + pass + + +def test_build_child_at_custom_scope_from_step() -> None: + step_container = Container(scope=Scope.STEP) + tenant_container = step_container.build_child_container(scope=MyScope.TENANT) + assert tenant_container.scope is MyScope.TENANT + assert tenant_container.parent_container is step_container + + +def test_build_child_at_custom_scope_from_app_skips_intermediate() -> None: + app_container = Container() + tenant_container = app_container.build_child_container(scope=MyScope.TENANT) + assert tenant_container.scope is MyScope.TENANT + + +def test_factory_resolves_through_custom_scope_container() -> None: + class TenantGroup(Group): + svc = providers.Factory(scope=MyScope.TENANT, creator=TenantService) + + app_container = Container(groups=[TenantGroup]) + tenant_container = app_container.build_child_container(scope=MyScope.TENANT) + + instance = tenant_container.resolve(TenantService) + assert isinstance(instance, TenantService) + + +def test_resolve_at_custom_scope_from_app_raises_scope_not_initialized() -> None: + class TenantGroup(Group): + svc = providers.Factory(scope=MyScope.TENANT, creator=TenantService) + + app_container = Container(groups=[TenantGroup]) + with pytest.raises(ScopeNotInitializedError, match="TENANT") as exc: + app_container.resolve(TenantService) + assert exc.value.provider_scope is MyScope.TENANT + assert exc.value.container_scope is Scope.APP + + +def test_resolve_app_provider_from_custom_scope_with_skipped_chain() -> None: + # A standalone tenant container that never went through APP -> ... chain + tenant_container = Container(scope=MyScope.TENANT) + app_factory = providers.Factory(creator=lambda: "x") + with pytest.raises(ScopeSkippedError, match="APP"): + tenant_container.resolve_provider(app_factory) + + +def test_invalid_child_scope_uses_parent_enum_for_allowed_list() -> None: + tenant_container = Container(scope=MyScope.TENANT) + with pytest.raises(InvalidChildScopeError) as exc: + tenant_container.build_child_container(scope=MyScope.TENANT) + # allowed_scopes must be drawn from the parent's own enum class (MyScope), + # not the standard Scope enum. + assert exc.value.allowed_scopes == ["BACKGROUND_JOB"] + + +def test_invalid_child_scope_with_conflicting_value() -> None: + app_container = Container() + with pytest.raises(InvalidChildScopeError) as exc: + app_container.build_child_container(scope=ConflictingScope.SAME_AS_APP) + assert exc.value.parent_scope is Scope.APP + assert exc.value.child_scope is ConflictingScope.SAME_AS_APP + + +def test_caching_isolated_across_tenant_containers() -> None: + class TenantGroup(Group): + svc = providers.Factory( + scope=MyScope.TENANT, + creator=TenantService, + cache_settings=providers.CacheSettings(), + ) + + app_container = Container(groups=[TenantGroup]) + tenant_a = app_container.build_child_container(scope=MyScope.TENANT) + tenant_b = app_container.build_child_container(scope=MyScope.TENANT) + + instance_a = tenant_a.resolve(TenantService) + instance_b = tenant_b.resolve(TenantService) + assert instance_a is not instance_b + assert tenant_a.resolve(TenantService) is instance_a + + +def test_auto_derive_within_custom_enum() -> None: + tenant_container = Container(scope=MyScope.TENANT) + bg_container = tenant_container.build_child_container() + assert bg_container.scope is MyScope.BACKGROUND_JOB