diff --git a/modern_di/container.py b/modern_di/container.py index 65b5234..06ecafa 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -100,9 +100,9 @@ def resolve_provider(self, provider: "AbstractProvider[types.T]") -> types.T: self.overrides_registry.overrides and (override := self.overrides_registry.fetch_override(provider.provider_id)) is not types.UNSET ): - return typing.cast(types.T, override) + return override # ty: ignore[invalid-return-type] - return typing.cast(types.T, provider.resolve(self)) + return provider.resolve(self) def validate_provider(self, provider: "AbstractProvider[types.T]") -> types.T: return typing.cast(types.T, provider.validate(self)) diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index f04496a..cf33f88 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -105,8 +105,15 @@ def _ensure_kwargs_cached( ) -> tuple[dict[str, "AbstractProvider[typing.Any]"], dict[str, typing.Any]]: if not cache_item.kwargs_compiled: kwargs = self._compile_kwargs(container) - cache_item.provider_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, AbstractProvider)} - cache_item.static_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, AbstractProvider)} + provider_kwargs: dict[str, AbstractProvider[typing.Any]] = {} + static_kwargs: dict[str, typing.Any] = {} + for k, v in kwargs.items(): + if isinstance(v, AbstractProvider): + provider_kwargs[k] = v + else: + static_kwargs[k] = v + cache_item.provider_kwargs = provider_kwargs + cache_item.static_kwargs = static_kwargs cache_item.kwargs_compiled = True return cache_item.provider_kwargs, cache_item.static_kwargs @@ -127,6 +134,10 @@ def validate(self, container: "Container") -> dict[str, typing.Any]: def resolve(self, container: "Container") -> types.T_co: container = container.find_container(self.scope) cache_item = container.cache_registry.fetch_cache_item(self) + + if self.cache_settings and cache_item.cache is not None: + return cache_item.cache + provider_kwargs, static_kwargs = self._ensure_kwargs_cached(container, cache_item) resolved_kwargs = dict(static_kwargs) for k, v in provider_kwargs.items(): @@ -135,15 +146,12 @@ def resolve(self, container: "Container") -> types.T_co: if not self.cache_settings: return self._creator(**resolved_kwargs) - if cache_item.cache is not None: - return typing.cast(types.T_co, cache_item.cache) - if container.lock: container.lock.acquire() try: if cache_item.cache is not None: - return typing.cast(types.T_co, cache_item.cache) + return cache_item.cache instance = self._creator(**resolved_kwargs) cache_item.cache = instance