Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def resolve_provider(self, provider: "AbstractProvider[types.T]") -> types.T:

return typing.cast(types.T, provider.resolve(self))

def validate_provider(self, provider: "AbstractProvider[types.T]") -> types.T:
return typing.cast(types.T, provider.validate(self))

async def close_async(self) -> None:
if not self.parent_container:
self.overrides_registry.reset_override()
Expand Down
3 changes: 3 additions & 0 deletions modern_di/providers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ def __init__(

@abc.abstractmethod
def resolve(self, container: "Container") -> typing.Any: ... # noqa: ANN401

@abc.abstractmethod
def validate(self, container: "Container") -> dict[str, typing.Any]: ...
3 changes: 3 additions & 0 deletions modern_di/providers/container_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@ def __init__(self) -> None:
def resolve(self, container: "Container") -> "Container":
return container

def validate(self, _: "Container") -> dict[str, typing.Any]:
return {"self": self}


container_provider = _ContainerProvider()
8 changes: 5 additions & 3 deletions modern_di/providers/context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@


class ContextProvider(AbstractProvider[types.T_co]):
__slots__ = [*AbstractProvider.BASE_SLOTS, "_context_type"]
__slots__ = AbstractProvider.BASE_SLOTS

def __init__(self, *, scope: Scope = Scope.APP, context_type: type[types.T_co]) -> None:
super().__init__(scope=scope, bound_type=context_type)
self._context_type = context_type

def validate(self, _: "Container") -> dict[str, typing.Any]:
return {"bound_type": self.bound_type, "self": self}

def resolve(self, container: "Container") -> types.T_co | None:
container = container.find_container(self.scope)
return container.context_registry.find_context(self._context_type)
return container.context_registry.find_context(typing.cast(type[types.T_co], self.bound_type))
17 changes: 17 additions & 0 deletions modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]:
result.update(self._kwargs)
return result

def validate(self, container: "Container") -> dict[str, typing.Any]:
container = container.find_container(self.scope)
cache_item = container.cache_registry.fetch_cache_item(self)
if cache_item.kwargs is not None:
kwargs = cache_item.kwargs
else:
kwargs = self._compile_kwargs(container)
cache_item.kwargs = kwargs

return {
"bound_type": self.bound_type,
"creator": self._creator,
"self": self,
"kwargs": {k: v.validate(container) if isinstance(v, AbstractProvider) else v for k, v in kwargs.items()},
"cache_settings": self.cache_settings,
}

def resolve(self, container: "Container") -> types.T_co:
container = container.find_container(self.scope)
cache_item = container.cache_registry.fetch_cache_item(self)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-asyncio",
"pytest-repeat",
]
lint = [
"ruff",
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_container_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test_container_provider_direct_resolving() -> None:

request_container = app_container.build_child_container(scope=Scope.REQUEST)
assert request_container.resolve_provider(providers.container_provider) is request_container
request_container.validate_provider(providers.container_provider)


def test_container_provider_sub_dependency() -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
def test_context_provider() -> None:
now = datetime.datetime.now(tz=datetime.timezone.utc)
app_container = Container(context={datetime.datetime: now})
app_container.validate_provider(context_provider)
instance1 = app_container.resolve_provider(context_provider)
instance2 = app_container.resolve_provider(context_provider)
assert instance1 is instance2 is now
Expand Down
7 changes: 5 additions & 2 deletions tests/providers/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class MyGroup(Group):
def test_app_factory() -> None:
app_container = Container(groups=[MyGroup])
instance1 = app_container.resolve_provider(MyGroup.app_factory)
app_container.validate_provider(MyGroup.app_factory)
instance2 = app_container.resolve(dependency_type=SimpleCreator)
assert isinstance(instance1, SimpleCreator)
assert isinstance(instance2, SimpleCreator)
Expand All @@ -59,7 +60,7 @@ def test_app_factory_skip_creator_parsing() -> None:
def test_app_factory_unresolvable() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match="Argument dep1 of type <class 'str'> cannot be resolved"):
app_container.resolve_provider(MyGroup.app_factory_unresolvable)
app_container.validate_provider(MyGroup.app_factory_unresolvable)


def test_func_with_union_factory() -> None:
Expand All @@ -71,14 +72,16 @@ def test_func_with_union_factory() -> None:
def test_func_with_broken_annotation() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match="Argument dep1 of type None cannot be resolved"):
app_container.resolve_provider(MyGroup.func_with_broken_annotation)
app_container.validate_provider(MyGroup.func_with_broken_annotation)


def test_request_factory() -> None:
app_container = Container(groups=[MyGroup])
request_container = app_container.build_child_container(scope=Scope.REQUEST)
request_container.validate_provider(MyGroup.request_factory)
instance1 = request_container.resolve_provider(MyGroup.request_factory)
instance2 = request_container.resolve_provider(MyGroup.request_factory)
request_container.validate_provider(MyGroup.request_factory)
assert instance1 is not instance2

request_container = app_container.build_child_container(scope=Scope.REQUEST)
Expand Down
Loading