From 42a553839289eaf0272814cf3a395b165d34757b Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Mon, 23 Feb 2026 14:22:06 +0300 Subject: [PATCH] add validate method to providers --- modern_di/container.py | 3 +++ modern_di/providers/abstract.py | 3 +++ modern_di/providers/container_provider.py | 3 +++ modern_di/providers/context_provider.py | 8 +++++--- modern_di/providers/factory.py | 17 +++++++++++++++++ pyproject.toml | 1 + tests/providers/test_container_provider.py | 1 + tests/providers/test_context_provider.py | 1 + tests/providers/test_factory.py | 7 +++++-- 9 files changed, 39 insertions(+), 5 deletions(-) diff --git a/modern_di/container.py b/modern_di/container.py index e1f8d3e..64374d1 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -104,6 +104,9 @@ def resolve_provider(self, provider: "AbstractProvider[types.T]") -> types.T: return typing.cast(types.T, provider.resolve(self)) + def validate_provider(self, provider: "AbstractProvider[types.T]") -> types.T: + return typing.cast(types.T, provider.validate(self)) + async def close_async(self) -> None: if not self.parent_container: self.overrides_registry.reset_override() diff --git a/modern_di/providers/abstract.py b/modern_di/providers/abstract.py index 4224bd1..bed820f 100644 --- a/modern_di/providers/abstract.py +++ b/modern_di/providers/abstract.py @@ -25,3 +25,6 @@ def __init__( @abc.abstractmethod def resolve(self, container: "Container") -> typing.Any: ... # noqa: ANN401 + + @abc.abstractmethod + def validate(self, container: "Container") -> dict[str, typing.Any]: ... diff --git a/modern_di/providers/container_provider.py b/modern_di/providers/container_provider.py index 3abe243..0887948 100644 --- a/modern_di/providers/container_provider.py +++ b/modern_di/providers/container_provider.py @@ -17,5 +17,8 @@ def __init__(self) -> None: def resolve(self, container: "Container") -> "Container": return container + def validate(self, _: "Container") -> dict[str, typing.Any]: + return {"self": self} + container_provider = _ContainerProvider() diff --git a/modern_di/providers/context_provider.py b/modern_di/providers/context_provider.py index 0f6a317..fe0f0d9 100644 --- a/modern_di/providers/context_provider.py +++ b/modern_di/providers/context_provider.py @@ -10,12 +10,14 @@ class ContextProvider(AbstractProvider[types.T_co]): - __slots__ = [*AbstractProvider.BASE_SLOTS, "_context_type"] + __slots__ = AbstractProvider.BASE_SLOTS def __init__(self, *, scope: Scope = Scope.APP, context_type: type[types.T_co]) -> None: super().__init__(scope=scope, bound_type=context_type) - self._context_type = context_type + + def validate(self, _: "Container") -> dict[str, typing.Any]: + return {"bound_type": self.bound_type, "self": self} def resolve(self, container: "Container") -> types.T_co | None: container = container.find_container(self.scope) - return container.context_registry.find_context(self._context_type) + return container.context_registry.find_context(typing.cast(type[types.T_co], self.bound_type)) diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index 8e118c0..1cad085 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -76,6 +76,23 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]: result.update(self._kwargs) return result + def validate(self, container: "Container") -> dict[str, typing.Any]: + container = container.find_container(self.scope) + cache_item = container.cache_registry.fetch_cache_item(self) + if cache_item.kwargs is not None: + kwargs = cache_item.kwargs + else: + kwargs = self._compile_kwargs(container) + cache_item.kwargs = kwargs + + return { + "bound_type": self.bound_type, + "creator": self._creator, + "self": self, + "kwargs": {k: v.validate(container) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()}, + "cache_settings": self.cache_settings, + } + def resolve(self, container: "Container") -> types.T_co: container = container.find_container(self.scope) cache_item = container.cache_registry.fetch_cache_item(self) diff --git a/pyproject.toml b/pyproject.toml index 18b9cb8..7f87b0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dev = [ "pytest", "pytest-cov", "pytest-asyncio", + "pytest-repeat", ] lint = [ "ruff", diff --git a/tests/providers/test_container_provider.py b/tests/providers/test_container_provider.py index 27169d7..57521ea 100644 --- a/tests/providers/test_container_provider.py +++ b/tests/providers/test_container_provider.py @@ -7,6 +7,7 @@ def test_container_provider_direct_resolving() -> None: request_container = app_container.build_child_container(scope=Scope.REQUEST) assert request_container.resolve_provider(providers.container_provider) is request_container + request_container.validate_provider(providers.container_provider) def test_container_provider_sub_dependency() -> None: diff --git a/tests/providers/test_context_provider.py b/tests/providers/test_context_provider.py index 4bb1047..77be649 100644 --- a/tests/providers/test_context_provider.py +++ b/tests/providers/test_context_provider.py @@ -10,6 +10,7 @@ def test_context_provider() -> None: now = datetime.datetime.now(tz=datetime.timezone.utc) app_container = Container(context={datetime.datetime: now}) + app_container.validate_provider(context_provider) instance1 = app_container.resolve_provider(context_provider) instance2 = app_container.resolve_provider(context_provider) assert instance1 is instance2 is now diff --git a/tests/providers/test_factory.py b/tests/providers/test_factory.py index 4eab86d..69ed90e 100644 --- a/tests/providers/test_factory.py +++ b/tests/providers/test_factory.py @@ -42,6 +42,7 @@ class MyGroup(Group): def test_app_factory() -> None: app_container = Container(groups=[MyGroup]) instance1 = app_container.resolve_provider(MyGroup.app_factory) + app_container.validate_provider(MyGroup.app_factory) instance2 = app_container.resolve(dependency_type=SimpleCreator) assert isinstance(instance1, SimpleCreator) assert isinstance(instance2, SimpleCreator) @@ -59,7 +60,7 @@ def test_app_factory_skip_creator_parsing() -> None: def test_app_factory_unresolvable() -> None: app_container = Container(groups=[MyGroup]) with pytest.raises(RuntimeError, match="Argument dep1 of type cannot be resolved"): - app_container.resolve_provider(MyGroup.app_factory_unresolvable) + app_container.validate_provider(MyGroup.app_factory_unresolvable) def test_func_with_union_factory() -> None: @@ -71,14 +72,16 @@ def test_func_with_union_factory() -> None: def test_func_with_broken_annotation() -> None: app_container = Container(groups=[MyGroup]) with pytest.raises(RuntimeError, match="Argument dep1 of type None cannot be resolved"): - app_container.resolve_provider(MyGroup.func_with_broken_annotation) + app_container.validate_provider(MyGroup.func_with_broken_annotation) def test_request_factory() -> None: app_container = Container(groups=[MyGroup]) request_container = app_container.build_child_container(scope=Scope.REQUEST) + request_container.validate_provider(MyGroup.request_factory) instance1 = request_container.resolve_provider(MyGroup.request_factory) instance2 = request_container.resolve_provider(MyGroup.request_factory) + request_container.validate_provider(MyGroup.request_factory) assert instance1 is not instance2 request_container = app_container.build_child_container(scope=Scope.REQUEST)