diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0f58873 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,53 @@ +name: main + +on: + push: + branches: + - main + pull_request: {} + +concurrency: + group: ${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: extractions/setup-just@v2 + - uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" + - run: uv python install 3.10 + - run: just install lint-ci + + pytest: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: + - "3.10" + - "3.11" + - "3.12" + - "3.13" + - "3.14" + steps: + - uses: actions/checkout@v4 + - uses: extractions/setup-just@v2 + - uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" + - run: uv python install ${{ matrix.python-version }} + - run: just install + - run: just test . --cov=. --cov-report xml + - uses: codecov/codecov-action@v4.0.1 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + files: ./coverage.xml + flags: unittests + name: codecov-${{ matrix.python-version }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..b637272 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,17 @@ +name: Publish Package + +on: + release: + types: + - published + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: extractions/setup-just@v2 + - uses: astral-sh/setup-uv@v3 + - run: just publish + env: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..068012f --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Generic things +*.pyc +*~ +__pycache__/* +*.swp +*.sqlite3 +*.map +.vscode +.idea +.DS_Store +.env +.mypy_cache +.pytest_cache +.ruff_cache +.coverage +htmlcov/ +coverage.xml +pytest.xml +dist/ +.python-version +.venv +uv.lock diff --git a/Justfile b/Justfile new file mode 100644 index 0000000..fabf8c0 --- /dev/null +++ b/Justfile @@ -0,0 +1,29 @@ +default: install lint test + +install: + uv lock --upgrade + uv sync --all-extras --frozen --group lint + +lint: + uv run eof-fixer . + uv run ruff format + uv run ruff check --fix + uv run mypy . + +lint-ci: + uv run eof-fixer . --check + uv run ruff format --check + uv run ruff check --no-fix + uv run mypy . + +test *args: + uv run --no-sync pytest {{ args }} + +test-branch: + @just test --cov-branch + +publish: + rm -rf dist + uv version $GITHUB_REF_NAME + uv build + uv publish --token $PYPI_TOKEN diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a176c1b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 modern-python + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/modern_di_fastapi/__init__.py b/modern_di_fastapi/__init__.py new file mode 100644 index 0000000..5803e33 --- /dev/null +++ b/modern_di_fastapi/__init__.py @@ -0,0 +1,18 @@ +from modern_di_fastapi.main import ( + FromDI, + build_di_container, + fastapi_request_provider, + fastapi_websocket_provider, + fetch_di_container, + setup_di, +) + + +__all__ = [ + "FromDI", + "build_di_container", + "fastapi_request_provider", + "fastapi_websocket_provider", + "fetch_di_container", + "setup_di", +] diff --git a/modern_di_fastapi/main.py b/modern_di_fastapi/main.py new file mode 100644 index 0000000..c044d7f --- /dev/null +++ b/modern_di_fastapi/main.py @@ -0,0 +1,71 @@ +import contextlib +import dataclasses +import typing + +import fastapi +from fastapi.routing import _merge_lifespan_context +from modern_di import Container, Scope, providers +from starlette.requests import HTTPConnection + + +T_co = typing.TypeVar("T_co", covariant=True) + + +fastapi_request_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=fastapi.Request) +fastapi_websocket_provider = providers.ContextProvider(scope=Scope.SESSION, context_type=fastapi.WebSocket) + + +def fetch_di_container(app_: fastapi.FastAPI) -> Container: + return typing.cast(Container, app_.state.di_container) + + +@contextlib.asynccontextmanager +async def _lifespan_manager(app_: fastapi.FastAPI) -> typing.AsyncIterator[None]: + container = fetch_di_container(app_) + try: + yield + finally: + await container.close_async() + + +def setup_di(app: fastapi.FastAPI, container: Container) -> Container: + app.state.di_container = container + container.providers_registry.add_providers(fastapi_request_provider, fastapi_websocket_provider) + old_lifespan_manager = app.router.lifespan_context + app.router.lifespan_context = _merge_lifespan_context( + old_lifespan_manager, + _lifespan_manager, + ) + return container + + +async def build_di_container(connection: HTTPConnection) -> typing.AsyncIterator[Container]: + context: dict[type[typing.Any], typing.Any] = {} + scope: Scope | None = None + if isinstance(connection, fastapi.Request): + scope = Scope.REQUEST + context[fastapi.Request] = connection + elif isinstance(connection, fastapi.WebSocket): + context[fastapi.WebSocket] = connection + scope = Scope.SESSION + container = fetch_di_container(connection.app).build_child_container(context=context, scope=scope) + try: + yield container + finally: + await container.close_async() + + +@dataclasses.dataclass(slots=True, frozen=True) +class Dependency(typing.Generic[T_co]): + dependency: providers.AbstractProvider[T_co] | type[T_co] + + async def __call__( + self, request_container: typing.Annotated[Container, fastapi.Depends(build_di_container)] + ) -> T_co: + if isinstance(self.dependency, providers.AbstractProvider): + return request_container.resolve_provider(self.dependency) + return request_container.resolve(dependency_type=self.dependency) + + +def FromDI(dependency: providers.AbstractProvider[T_co] | type[T_co], *, use_cache: bool = True) -> T_co: # noqa: N802 + return typing.cast(T_co, fastapi.Depends(dependency=Dependency(dependency), use_cache=use_cache)) diff --git a/modern_di_fastapi/py.typed b/modern_di_fastapi/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8eb41d9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,83 @@ +[project] +name = "modern-di-fastapi" +description = "Modern-DI integration for FastAPI" +authors = [{ name = "Artur Shiriev", email = "me@shiriev.ru" }] +requires-python = ">=3.10,<4" +license = "MIT" +readme = "README.md" +keywords = ["DI", "dependency injector", "ioc-container", "FastAPI", "python"] +classifiers = [ + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Typing :: Typed", + "Topic :: Software Development :: Libraries", +] +dependencies = ["fastapi>=0.100", "modern-di>=2,<3"] +version = "0" + +[project.urls] +repository = "https://github.com/modern-python/modern-di-fastapi" +docs = "https://modern-di.readthedocs.io" + +[dependency-groups] +dev = [ + "pytest", + "pytest-cov", + "pytest-asyncio", + "httpx", +] +lint = [ + "mypy", + "ruff", + "eof-fixer", + "typing-extensions", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["modern_di_fastapi"] + +[tool.mypy] +python_version = "3.10" +strict = true + +[tool.ruff] +fix = false +unsafe-fixes = true +line-length = 120 +target-version = "py310" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "D1", # allow missing docstrings + "S101", # allow asserts + "TCH", # ignore flake8-type-checking + "FBT", # allow boolean args + "D203", # "one-blank-line-before-class" conflicting with D211 + "D213", # "multi-line-summary-second-line" conflicting with D212 + "COM812", # flake8-commas "Trailing comma missing" + "ISC001", # flake8-implicit-str-concat + "G004", # allow f-strings in logging +] +isort.lines-after-imports = 2 +isort.no-lines-before = ["standard-library", "local-folder"] + +[tool.pytest.ini_options] +addopts = "--cov=. --cov-report term-missing" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.report] +exclude_also = [ + "if typing.TYPE_CHECKING:", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5eaf2aa --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,23 @@ +import typing + +import fastapi +import modern_di +import pytest +from starlette.testclient import TestClient + +import modern_di_fastapi +from tests.dependencies import Dependencies + + +@pytest.fixture +async def app() -> fastapi.FastAPI: + app_ = fastapi.FastAPI() + container = modern_di.Container(groups=[Dependencies]) + modern_di_fastapi.setup_di(app_, container=container) + return app_ + + +@pytest.fixture +def client(app: fastapi.FastAPI) -> typing.Iterator[TestClient]: + with TestClient(app=app) as test_client: + yield test_client diff --git a/tests/dependencies.py b/tests/dependencies.py new file mode 100644 index 0000000..9044f66 --- /dev/null +++ b/tests/dependencies.py @@ -0,0 +1,33 @@ +import dataclasses + +import fastapi +from modern_di import Group, Scope, providers + + +@dataclasses.dataclass(kw_only=True, slots=True) +class SimpleCreator: + dep1: str + + +@dataclasses.dataclass(kw_only=True, slots=True) +class DependentCreator: + dep1: SimpleCreator + + +def fetch_method_from_request(request: fastapi.Request) -> str: + assert isinstance(request, fastapi.Request) + return request.method + + +def fetch_url_from_websocket(websocket: fastapi.WebSocket) -> str: + assert isinstance(websocket, fastapi.WebSocket) + return websocket.url.path + + +class Dependencies(Group): + app_factory = providers.Factory(creator=SimpleCreator, kwargs={"dep1": "original"}) + session_factory = providers.Factory(scope=Scope.SESSION, creator=DependentCreator, bound_type=None) + request_factory = providers.Factory(scope=Scope.REQUEST, creator=DependentCreator, bound_type=None) + action_factory = providers.Factory(scope=Scope.ACTION, creator=DependentCreator, bound_type=None) + request_method = providers.Factory(scope=Scope.REQUEST, creator=fetch_method_from_request, bound_type=None) + websocket_path = providers.Factory(scope=Scope.SESSION, creator=fetch_url_from_websocket, bound_type=None) diff --git a/tests/test_routes.py b/tests/test_routes.py new file mode 100644 index 0000000..bc1f6c0 --- /dev/null +++ b/tests/test_routes.py @@ -0,0 +1,50 @@ +import typing + +import fastapi +from modern_di import Container +from starlette import status +from starlette.testclient import TestClient + +from modern_di_fastapi import FromDI, build_di_container +from tests.dependencies import Dependencies, DependentCreator, SimpleCreator + + +def test_factories(client: TestClient, app: fastapi.FastAPI) -> None: + @app.get("/") + async def read_root( + app_factory_instance: typing.Annotated[SimpleCreator, FromDI(SimpleCreator)], + request_factory_instance: typing.Annotated[DependentCreator, FromDI(Dependencies.request_factory)], + ) -> None: + assert isinstance(app_factory_instance, SimpleCreator) + assert isinstance(request_factory_instance, DependentCreator) + assert request_factory_instance.dep1 is not app_factory_instance + + response = client.get("/") + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + +def test_context_provider(client: TestClient, app: fastapi.FastAPI) -> None: + @app.get("/") + async def read_root( + method: typing.Annotated[str, FromDI(Dependencies.request_method)], + ) -> None: + assert method == "GET" + + response = client.get("/") + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + +def test_factories_action_scope(client: TestClient, app: fastapi.FastAPI) -> None: + @app.get("/") + async def read_root( + request_container: typing.Annotated[Container, fastapi.Depends(build_di_container)], + ) -> None: + action_container = request_container.build_child_container() + action_factory_instance = action_container.resolve_provider(Dependencies.action_factory) + assert isinstance(action_factory_instance, DependentCreator) + + response = client.get("/") + assert response.status_code == status.HTTP_200_OK + assert response.json() is None diff --git a/tests/test_websockets.py b/tests/test_websockets.py new file mode 100644 index 0000000..3aaad18 --- /dev/null +++ b/tests/test_websockets.py @@ -0,0 +1,64 @@ +import typing + +import fastapi +from modern_di import Container +from starlette.testclient import TestClient + +from modern_di_fastapi import FromDI, build_di_container +from tests.dependencies import Dependencies, DependentCreator, SimpleCreator + + +async def test_factories(client: TestClient, app: fastapi.FastAPI) -> None: + @app.websocket("/ws") + async def websocket_endpoint( + websocket: fastapi.WebSocket, + app_factory_instance: typing.Annotated[SimpleCreator, FromDI(SimpleCreator)], + session_factory_instance: typing.Annotated[DependentCreator, FromDI(Dependencies.session_factory)], + ) -> None: + assert isinstance(app_factory_instance, SimpleCreator) + assert isinstance(session_factory_instance, DependentCreator) + assert session_factory_instance.dep1 is not app_factory_instance + + await websocket.accept() + await websocket.send_text("test") + await websocket.close() + + with client.websocket_connect("/ws") as websocket: + data = websocket.receive_text() + assert data == "test" + + +async def test_factories_request_scope(client: TestClient, app: fastapi.FastAPI) -> None: + @app.websocket("/ws") + async def websocket_endpoint( + websocket: fastapi.WebSocket, + session_container: typing.Annotated[Container, fastapi.Depends(build_di_container)], + ) -> None: + request_container = session_container.build_child_container() + request_factory_instance = request_container.resolve_provider(Dependencies.request_factory) + assert isinstance(request_factory_instance, DependentCreator) + + await websocket.accept() + await websocket.send_text("test") + await websocket.close() + + with client.websocket_connect("/ws") as websocket: + data = websocket.receive_text() + assert data == "test" + + +async def test_context_provider(client: TestClient, app: fastapi.FastAPI) -> None: + @app.websocket("/ws") + async def websocket_endpoint( + websocket: fastapi.WebSocket, + path: typing.Annotated[str, FromDI(Dependencies.websocket_path)], + ) -> None: + assert path == "/ws" + + await websocket.accept() + await websocket.send_text("test") + await websocket.close() + + with client.websocket_connect("/ws") as websocket: + data = websocket.receive_text() + assert data == "test"