From 5fec14237ed3578c2b2c9e8522f8a5c5bd0d11ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Tue, 31 Mar 2026 13:02:38 +0200 Subject: [PATCH 1/5] feature(datashare-python): implement dependency discovery --- datashare-python/tests/test_discovery.py | 75 ++++++++++++++++++- worker-template/pyproject.toml | 3 + .../worker_template/dependencies.py | 10 +++ 3 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 worker-template/worker_template/dependencies.py diff --git a/datashare-python/tests/test_discovery.py b/datashare-python/tests/test_discovery.py index 47132b5..871a0fc 100644 --- a/datashare-python/tests/test_discovery.py +++ b/datashare-python/tests/test_discovery.py @@ -1,5 +1,15 @@ +import re +from importlib.metadata import EntryPoints +from unittest.mock import MagicMock + +import datashare_python import pytest -from datashare_python.discovery import discover_activities, discover_workflows +from _pytest.monkeypatch import MonkeyPatch +from datashare_python.discovery import ( + discover_activities, + discover_dependencies, + discover_workflows, +) @pytest.mark.parametrize( @@ -42,3 +52,66 @@ def test_discover_activities(names: list[str], expected_activities: set[str]) -> activities = {act_name for act_name, _ in discover_activities(names)} # Then assert activities == expected_activities + + +@pytest.mark.parametrize("name", ["base", None]) +def test_discover_dependencies(name: str | None) -> None: + # When + deps = discover_dependencies(name) + # Then + expected_deps = [ + "set_loggers", + "set_event_loop", + "set_es_client", + "set_temporal_client", + ] + assert [d.__name__ for d in deps] == expected_deps + + +def test_discover_dependencies_should_raise_for_unknown_dep() -> None: + # Given + unknown_dep = "unknown_dep" + # When/Then + expected = ( + 'failed to find dependency for name "unknown_dep", ' + "available dependencies: ['base']" + ) + with pytest.raises(LookupError, match=re.escape(expected)): + 
discover_dependencies(unknown_dep) + + +def test_discover_dependencies_should_raise_for_conflicting_deps( + monkeypatch: MonkeyPatch, +) -> None: + # Given + def mocked_entry_points(name: str, group: str) -> EntryPoints: # noqa: ARG001 + entry_points = MagicMock() + entry_points.__len__.return_value = 2 + return entry_points + + monkeypatch.setattr(datashare_python.discovery, "entry_points", mocked_entry_points) + # When/Then + expected = "found multiple dependencies for name" + with pytest.raises(ValueError, match=re.escape(expected)): + discover_dependencies(name=None) + + +def test_discover_dependencies_should_raise_for_multiple_entry_points( + monkeypatch: MonkeyPatch, +) -> None: + # Given + def mocked_entry_points(name: str, group: str) -> EntryPoints: # noqa: ARG001 + ep = MagicMock() + ep.load.return_value = {"a": [], "b": []} + entry_points = MagicMock() + entry_points.__getitem__.return_value = ep + return entry_points + + monkeypatch.setattr(datashare_python.discovery, "entry_points", mocked_entry_points) + # When/Then + expected = ( + 'dependency registry contains multiples entries "a", "b",' + " please select one by providing a name" + ) + with pytest.raises(ValueError, match=re.escape(expected)): + discover_dependencies(name=None) diff --git a/worker-template/pyproject.toml b/worker-template/pyproject.toml index 2264de4..05b9bf3 100644 --- a/worker-template/pyproject.toml +++ b/worker-template/pyproject.toml @@ -28,6 +28,9 @@ workflows = "worker_template.workflows:WORKFLOWS" [project.entry-points."datashare.activities"] activities = "worker_template.activities:ACTIVITIES" +[project.entry-points."datashare.dependencies"] +dependencies = "worker_template.dependencies:DEPENDENCIES" + [tool.uv.sources] torch = [ { index = "pytorch-cpu" }, diff --git a/worker-template/worker_template/dependencies.py b/worker-template/worker_template/dependencies.py new file mode 100644 index 0000000..64dd6d3 --- /dev/null +++ 
b/worker-template/worker_template/dependencies.py @@ -0,0 +1,10 @@ +from datashare_python.dependencies import ( + set_es_client, + set_event_loop, + set_loggers, + set_temporal_client, +) + +BASE = [set_loggers, set_event_loop, set_es_client, set_temporal_client] + +DEPENDENCIES = {"base": BASE} From edc8e86405e177da8f93b9a22e35d6e9b10ff546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Tue, 31 Mar 2026 14:44:58 +0200 Subject: [PATCH 2/5] feature(datashare-python): start worker and run dependencies --- .../datashare_python/cli/worker.py | 94 +++++++++++++------ datashare-python/datashare_python/conftest.py | 8 ++ .../datashare_python/discovery.py | 94 ++++++++++++++++++- datashare-python/datashare_python/worker.py | 13 +++ datashare-python/pyproject.toml | 1 + datashare-python/tests/cli/test_worker.py | 12 ++- datashare-python/uv.lock | 2 + worker-template/worker_template/activities.py | 15 ++- 8 files changed, 200 insertions(+), 39 deletions(-) diff --git a/datashare-python/datashare_python/cli/worker.py b/datashare-python/datashare_python/cli/worker.py index 4e6ac42..4d8672d 100644 --- a/datashare-python/datashare_python/cli/worker.py +++ b/datashare-python/datashare_python/cli/worker.py @@ -1,12 +1,19 @@ +import asyncio import logging +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from pathlib import Path from typing import Annotated import typer +import yaml +from datashare_python.config import WorkerConfig from datashare_python.constants import DEFAULT_NAMESPACE, DEFAULT_TEMPORAL_ADDRESS -from datashare_python.discovery import discover_activities, discover_workflows +from datashare_python.dependencies import with_dependencies +from datashare_python.discovery import discover, discover_activities, discover_workflows from datashare_python.types_ import TemporalClient -from datashare_python.worker import datashare_worker +from datashare_python.worker import create_worker_id, datashare_worker from 
.utils import AsyncTyper @@ -20,6 +27,12 @@ _START_WORKER_WORKFLOWS_HELP = "workflow names run by the worker (supports regexes)" _START_WORKER_ACTIVITIES_HELP = "activity names run by the worker (supports regexes)" +_START_WORKER_DEPS_HELP = "worker lifetime dependencies name in the registry" +_START_WORKER_WORKER_ID_PREFIX_HELP = "worker ID prefix" +_START_WORKER_CONFIG_PATH_HELP = ( + "path to a worker config YAML file," + " if not provided will load worker configuration from env variables" +) _WORKER_QUEUE_HELP = "worker task queue" _WORKER_MAX_ACTIVITIES_HELP = ( "maximum number of concurrent activities/tasks" @@ -73,6 +86,18 @@ async def start( workflows: Annotated[list[str], typer.Option(help=_START_WORKER_WORKFLOWS_HELP)], activities: Annotated[list[str], typer.Option(help=_START_WORKER_ACTIVITIES_HELP)], queue: Annotated[str, typer.Option("--queue", "-q", help=_WORKER_QUEUE_HELP)], + dependencies: Annotated[ + str | None, typer.Option(help=_START_WORKER_DEPS_HELP) + ] = None, + config_path: Annotated[ + Path | None, + typer.Option( + "--config-path", "--config", "-c", help=_START_WORKER_CONFIG_PATH_HELP + ), + ] = None, + worker_id_prefix: Annotated[ + str | None, typer.Option(help=_START_WORKER_WORKER_ID_PREFIX_HELP) + ] = None, temporal_address: Annotated[ str, typer.Option("--temporal-address", "-a", help=_TEMPORAL_URL_HELP) ] = DEFAULT_TEMPORAL_ADDRESS, @@ -83,32 +108,43 @@ async def start( int, typer.Option("--max-activities", help=_WORKER_MAX_ACTIVITIES_HELP) ] = 1, ) -> None: - wf_names, wfs = zip(*discover_workflows(workflows), strict=False) - registered = "" - if wf_names: - n_wfs = len(wf_names) - registered += ( - f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {','.join(wf_names)}" + if config_path is not None: + with config_path.open() as f: + bootstrap_config = WorkerConfig.model_validate( + yaml.load(f, Loader=yaml.Loader) + ) + else: + bootstrap_config = WorkerConfig() + wfs, acts, deps = discover(workflows, act_names=activities, 
deps_name=dependencies) + worker_id = create_worker_id(worker_id_prefix or "worker") + event_loop = asyncio.get_event_loop() + deps_cm = ( + with_dependencies( + deps, + worker_config=bootstrap_config, + worker_id=worker_id, + event_loop=event_loop, ) - act_names, acts = zip(*discover_activities(activities), strict=False) - if act_names: - if registered: - registered += "\n" - i = len(act_names) - registered += f"- {i} activit{'ies' if i > 1 else 'y'}: {','.join(act_names)}" - if not acts and not wfs: - raise ValueError("Couldn't find any registered activity or workflow.") - logger.info("Starting datashare worker running:\n%s", registered) - client = await TemporalClient.connect(temporal_address, namespace=namespace) - worker = datashare_worker( - client, - workflows=wfs, - activities=acts, - task_queue=queue, - max_concurrent_activities=max_concurrent_activities, + if deps + else _do_nothing_cm ) - try: - await worker.run() - except Exception as e: # noqa: BLE001 - await worker.shutdown() - raise e + async with deps_cm: + client = await TemporalClient.connect(temporal_address, namespace=namespace) + worker = datashare_worker( + client, + worker_id, + workflows=wfs, + activities=acts, + task_queue=queue, + max_concurrent_activities=max_concurrent_activities, + ) + try: + await worker.run() + except Exception as e: # noqa: BLE001 + await worker.shutdown() + raise e + + +@asynccontextmanager +async def _do_nothing_cm() -> AsyncGenerator[None, None]: + yield diff --git a/datashare-python/datashare_python/conftest.py b/datashare-python/datashare_python/conftest.py index 31fadb1..a96eee7 100644 --- a/datashare-python/datashare_python/conftest.py +++ b/datashare-python/datashare_python/conftest.py @@ -1,6 +1,7 @@ import asyncio from asyncio import AbstractEventLoop from collections.abc import AsyncGenerator, Generator, Iterator, Sequence +from pathlib import Path import aiohttp import pytest @@ -102,6 +103,13 @@ def test_worker_config() -> WorkerConfig: ) 
+@pytest.fixture +def test_worker_config_path(test_worker_config: WorkerConfig, tmpdir: Path) -> Path: + config_path = Path(tmpdir) / "config.json" + config_path.write_text(test_worker_config.model_dump_json()) + return config_path + + @pytest.fixture(scope="session") async def worker_lifetime_deps( event_loop: AbstractEventLoop, diff --git a/datashare-python/datashare_python/discovery.py b/datashare-python/datashare_python/discovery.py index 659dce8..be00611 100644 --- a/datashare-python/datashare_python/discovery.py +++ b/datashare-python/datashare_python/discovery.py @@ -1,16 +1,69 @@ +import logging import re from collections.abc import Callable, Iterable from importlib.metadata import entry_points +from .types_ import ContextManagerFactory from .utils import ActivityWithProgress +logger = logging.getLogger(__name__) + Activity = ActivityWithProgress | Callable | type +_DEPENDENCIES = "dependencies" _WORKFLOW_GROUPS = "datashare.workflows" _ACTIVITIES_GROUPS = "datashare.activities" +_DEPENDENCIES_GROUPS = "datashare.dependencies" + +_RegisteredWorkflow = tuple[str, type] +_RegisteredActivity = tuple[str, Activity] +_Dependencies = list[ContextManagerFactory] +_Discovery = tuple[ + Iterable[_RegisteredWorkflow] | None, + Iterable[_RegisteredActivity] | None, + _Dependencies | None, +] -def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]: +def discover( + wf_names: list[str] | None, *, act_names: list[str] | None, deps_name: str | None +) -> _Discovery: + discovered = "" + wfs = None + if wf_names is not None: + wf_names, wfs = zip(*discover_workflows(wf_names), strict=True) + if wf_names: + n_wfs = len(wf_names) + discovered += ( + f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {', '.join(wf_names)}" + ) + acts = None + if act_names is not None: + act_names, acts = zip(*discover_activities(act_names), strict=True) + if act_names: + if discovered: + discovered += "\n" + n_acts = len(act_names) + discovered += ( + f"- {n_acts} activit{'ies' if 
n_acts > 1 else 'y'}:" + f" {', '.join(act_names)}" + ) + if not acts and not wfs: + raise ValueError("Couldn't find any registered activity or workflow.") + deps = discover_dependencies(deps_name) + if deps: + n_deps = len(deps) + discovered += "\n" + deps_names = (d.__name__ for d in deps) + discovered += ( + f"- {n_deps} dependenc{'ies' if n_deps > 1 else 'y'}:" + f" {', '.join(deps_names)}" + ) + logger.info("discovered:\n%s", discovered) + return wfs, acts, deps + + +def discover_workflows(names: list[str]) -> Iterable[_RegisteredWorkflow]: pattern = None if not names else re.compile(rf"^{'|'.join(names)}$") impls = entry_points(group=_WORKFLOW_GROUPS) for wf_impls in impls: @@ -24,7 +77,7 @@ def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]: yield wf_name, wf_impl -def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]: +def discover_activities(names: list[str]) -> Iterable[_RegisteredActivity]: pattern = None if not names else re.compile(rf"^{'|'.join(names)}$") impls = entry_points(group=_ACTIVITIES_GROUPS) for act_impls in impls: @@ -38,6 +91,43 @@ def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]: yield act_name, act_impl +def discover_dependencies(name: str | None) -> _Dependencies | None: + impls = entry_points(name=_DEPENDENCIES, group=_DEPENDENCIES_GROUPS) + if not impls: + if name is None: + return None + available_impls = entry_points(group=_DEPENDENCIES_GROUPS) + msg = ( + f'failed to find dependency: "{name}", ' + f"available dependencies: {available_impls}" + ) + raise LookupError(msg) + if len(impls) > 1: + msg = f'found multiple dependencies for name "{name}": {impls}' + raise ValueError(msg) + deps_registry = impls[_DEPENDENCIES].load() + if name: + try: + return deps_registry[name] + except KeyError as e: + available = list(deps_registry) + msg = ( + f'failed to find dependency for name "{name}", available dependencies: ' + f"{available}" + ) + raise LookupError(msg) from e + 
if not deps_registry: + raise ValueError("empty dependency registry !") + if len(deps_registry) > 1: + available = ", ".join('"' + d + '"' for d in deps_registry) + msg = ( + f"dependency registry contains multiples entries {available}," + f" please select one by providing a name" + ) + raise ValueError(msg) + return next(iter(deps_registry.values())) + + def _parse_wf_name(wf_type: type) -> str: if not isinstance(wf_type, type): msg = ( diff --git a/datashare-python/datashare_python/worker.py b/datashare-python/datashare_python/worker.py index 9a521a8..6701269 100644 --- a/datashare-python/datashare_python/worker.py +++ b/datashare-python/datashare_python/worker.py @@ -1,4 +1,7 @@ import logging +import os +import socket +import threading from concurrent.futures import ThreadPoolExecutor from temporalio.worker import PollerBehaviorSimpleMaximum, Worker @@ -30,6 +33,7 @@ def datashare_worker( client: TemporalClient, + worker_id: str, *, workflows: list[type] | None = None, activities: list[Activity] | None = None, @@ -60,6 +64,7 @@ def datashare_worker( return Worker( client, + identity=worker_id, workflows=workflows, activities=activities, task_queue=task_queue, @@ -72,3 +77,11 @@ def datashare_worker( # several of them workflow_task_poller_behavior=PollerBehaviorSimpleMaximum(5), ) + + +def create_worker_id(prefix: str) -> str: + pid = os.getpid() + threadid = threading.get_ident() + hostname = socket.gethostname() + # TODO: this might not be unique when using asyncio + return f"{prefix}-{hostname}-{pid}-{threadid}" diff --git a/datashare-python/pyproject.toml b/datashare-python/pyproject.toml index cd554cb..d02a94d 100644 --- a/datashare-python/pyproject.toml +++ b/datashare-python/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "datashare-worker-template[ml]~=0.1", "tomlkit~=0.14.0", "hatchling~=1.27.0", + "pyyaml~=6.0", ] [project.urls] diff --git a/datashare-python/tests/cli/test_worker.py b/datashare-python/tests/cli/test_worker.py index 9d195d8..b426cd9 
100644 --- a/datashare-python/tests/cli/test_worker.py +++ b/datashare-python/tests/cli/test_worker.py @@ -1,3 +1,5 @@ +from pathlib import Path + from _pytest.capture import CaptureFixture from _pytest.monkeypatch import MonkeyPatch from datashare_python.cli import cli_app @@ -10,12 +12,13 @@ async def _mock_worker_run(self) -> None: # noqa: ANN001 async def test_start_workers( - worker_lifetime_deps, # noqa: ANN001, ARG001 typer_asyncio_patch, # noqa: ANN001, ARG001 + test_worker_config_path: Path, monkeypatch: MonkeyPatch, capsys: CaptureFixture[str], ) -> None: # Given + config_path = test_worker_config_path runner = CliRunner(mix_stderr=False) monkeypatch.setattr(Worker, "run", _mock_worker_run) with capsys.disabled(): @@ -27,6 +30,8 @@ async def test_start_workers( "start", "--queue", "cpu", + "-c", + str(config_path), "--activities", "ping", "--activities", @@ -40,7 +45,8 @@ async def test_start_workers( ) # Then assert result.exit_code == 0 - expected = """Starting datashare worker running: + expected = """discovered: - 1 workflow: ping -- 1 activity: create-translation-batches""" +- 1 activity: create-translation-batches +- 4 dependencies: set_loggers, set_event_loop, set_es_client, set_temporal_client""" assert expected in result.stderr diff --git a/datashare-python/uv.lock b/datashare-python/uv.lock index c951da9..a63a3da 100644 --- a/datashare-python/uv.lock +++ b/datashare-python/uv.lock @@ -404,6 +404,7 @@ dependencies = [ { name = "icij-common", extra = ["elasticsearch"] }, { name = "nest-asyncio" }, { name = "python-json-logger" }, + { name = "pyyaml" }, { name = "temporalio" }, { name = "tomlkit" }, { name = "typer" }, @@ -433,6 +434,7 @@ requires-dist = [ { name = "icij-common", extras = ["elasticsearch"], specifier = "~=0.7.3" }, { name = "nest-asyncio", specifier = "~=1.6.0" }, { name = "python-json-logger", specifier = "~=4.0.0" }, + { name = "pyyaml", specifier = "~=6.0" }, { name = "temporalio", specifier = "~=1.23.0" }, { name = "tomlkit", 
specifier = "~=0.14.0" }, { name = "typer", specifier = "~=0.15.4" }, diff --git a/worker-template/worker_template/activities.py b/worker-template/worker_template/activities.py index 3f99b13..cfa94bb 100644 --- a/worker-template/worker_template/activities.py +++ b/worker-template/worker_template/activities.py @@ -2,6 +2,7 @@ import logging from collections.abc import AsyncGenerator, Generator, Iterable from functools import partial +from typing import TYPE_CHECKING from aiostream.stream import chain from datashare_python.objects import Document @@ -40,7 +41,9 @@ ) from temporalio import activity from temporalio.client import Client -from transformers import Pipeline, pipeline + +if TYPE_CHECKING: + from transformers import Pipeline from .objects_ import ClassificationConfig, TranslationConfig @@ -236,6 +239,7 @@ async def translate_docs( config: TranslationConfig | None = None, ) -> int: import torch # noqa:PLC0415 + from transformers import pipeline # noqa: PLC0415 if config is None: config = TranslationConfig() @@ -301,6 +305,7 @@ async def classify_docs( es_client: ESClient, ) -> int: import torch # noqa: PLC0415 + from transformers import pipeline # noqa: PLC0415 if config is None: config = ClassificationConfig() @@ -357,11 +362,11 @@ async def classify_docs( return n_docs -def _translate_as_list(pipe: Pipeline, texts: list[str]) -> list[str]: +def _translate_as_list(pipe: "Pipeline", texts: list[str]) -> list[str]: return list(_translate(pipe, texts)) -def _translate(pipe: Pipeline, texts: list[str]) -> Generator[str, None, None]: +def _translate(pipe: "Pipeline", texts: list[str]) -> Generator[str, None, None]: for res in pipe(texts): yield res["translation_text"] @@ -455,13 +460,13 @@ async def _count_untranslated( return res[COUNT] -def _classify(pipe: Pipeline, texts: list[str]) -> Generator[str, None, None]: +def _classify(pipe: "Pipeline", texts: list[str]) -> Generator[str, None, None]: # In practice, we should chunk the text for res in pipe(texts, 
padding=True, truncation=True): yield res["label"] -def _classify_as_list(pipe: Pipeline, texts: list[str]) -> list[str]: +def _classify_as_list(pipe: "Pipeline", texts: list[str]) -> list[str]: return list(_classify(pipe, texts)) From 527e75a666872986df2fe8e6506d4b927bf7b31e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Tue, 31 Mar 2026 15:16:10 +0200 Subject: [PATCH 3/5] feature(datashare-python): initialize activity with event loop and temporal worker --- .../datashare_python/cli/worker.py | 3 +- datashare-python/datashare_python/worker.py | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/datashare-python/datashare_python/cli/worker.py b/datashare-python/datashare_python/cli/worker.py index 4d8672d..eaffa89 100644 --- a/datashare-python/datashare_python/cli/worker.py +++ b/datashare-python/datashare_python/cli/worker.py @@ -13,7 +13,7 @@ from datashare_python.dependencies import with_dependencies from datashare_python.discovery import discover, discover_activities, discover_workflows from datashare_python.types_ import TemporalClient -from datashare_python.worker import create_worker_id, datashare_worker +from datashare_python.worker import create_worker_id, datashare_worker, init_activity from .utils import AsyncTyper @@ -130,6 +130,7 @@ async def start( ) async with deps_cm: client = await TemporalClient.connect(temporal_address, namespace=namespace) + acts = [init_activity(a, client=client, event_loop=event_loop) for a in acts] worker = datashare_worker( client, worker_id, diff --git a/datashare-python/datashare_python/worker.py b/datashare-python/datashare_python/worker.py index 6701269..fc8e8c6 100644 --- a/datashare-python/datashare_python/worker.py +++ b/datashare-python/datashare_python/worker.py @@ -1,7 +1,10 @@ +import inspect import logging import os import socket import threading +from asyncio import AbstractEventLoop +from collections.abc import Callable from concurrent.futures import 
ThreadPoolExecutor from temporalio.worker import PollerBehaviorSimpleMaximum, Worker @@ -85,3 +88,32 @@ def create_worker_id(prefix: str) -> str: hostname = socket.gethostname() # TODO: this might not be unique when using asyncio return f"{prefix}-{hostname}-{pid}-{threadid}" + + +_CLIENT = "client" +_WORKER = "worker" +_EXPECTED_INIT_ARGS = {"self", _CLIENT, _WORKER} + + +def _get_class_from_method(method: Callable) -> type: + qualname = method.__qualname__ + class_name = qualname.rsplit(".", 1)[0] + return method.__globals__.get(class_name) + + +def init_activity( + activity: Callable, client: TemporalClient, event_loop: AbstractEventLoop +) -> Callable: + if not inspect.ismethod(activity): + return activity + cls = _get_class_from_method(activity) + init_args = inspect.signature(cls.__init__).parameters + invalid = [p for p in init_args if p not in _EXPECTED_INIT_ARGS] + if invalid: + msg = f"invalid activity arguments: {invalid}" + raise ValueError(msg) + kwargs = {"client": client, "event_loop": event_loop} + kwargs = {k: v for k, v in kwargs.items() if k in _EXPECTED_INIT_ARGS} + if not kwargs: + return activity + return cls(**kwargs) From fc2bc428d692d223c046d7d41280c095ad792681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 1 Apr 2026 09:45:06 +0200 Subject: [PATCH 4/5] refactor(worker-template): use lifespan dependencies rather than init variables for activities --- .../datashare_python/cli/worker.py | 53 ++----- datashare-python/datashare_python/config.py | 2 + datashare-python/datashare_python/worker.py | 92 +++++++++--- worker-template/tests/conftest.py | 133 ++++++++---------- worker-template/worker_template/activities.py | 51 ++----- .../worker_template/dependencies.py | 5 +- 6 files changed, 160 insertions(+), 176 deletions(-) diff --git a/datashare-python/datashare_python/cli/worker.py b/datashare-python/datashare_python/cli/worker.py index eaffa89..8658691 100644 --- a/datashare-python/datashare_python/cli/worker.py 
+++ b/datashare-python/datashare_python/cli/worker.py @@ -1,7 +1,5 @@ import asyncio import logging -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager from pathlib import Path from typing import Annotated @@ -10,10 +8,9 @@ from datashare_python.config import WorkerConfig from datashare_python.constants import DEFAULT_NAMESPACE, DEFAULT_TEMPORAL_ADDRESS -from datashare_python.dependencies import with_dependencies from datashare_python.discovery import discover, discover_activities, discover_workflows from datashare_python.types_ import TemporalClient -from datashare_python.worker import create_worker_id, datashare_worker, init_activity +from datashare_python.worker import bootstrap_worker, create_worker_id from .utils import AsyncTyper @@ -34,10 +31,6 @@ " if not provided will load worker configuration from env variables" ) _WORKER_QUEUE_HELP = "worker task queue" -_WORKER_MAX_ACTIVITIES_HELP = ( - "maximum number of concurrent activities/tasks" - " concurrently run by the worker. Defaults to 1 to encourage horizontal scaling." 
-) _TEMPORAL_NAMESPACE_HELP = "worker temporal namespace" _TEMPORAL_URL_HELP = "address for temporal server" @@ -104,9 +97,6 @@ async def start( namespace: Annotated[ str, typer.Option("--temporal-namespace", "-ns", help=_TEMPORAL_NAMESPACE_HELP) ] = DEFAULT_NAMESPACE, - max_concurrent_activities: Annotated[ - int, typer.Option("--max-activities", help=_WORKER_MAX_ACTIVITIES_HELP) - ] = 1, ) -> None: if config_path is not None: with config_path.open() as f: @@ -115,37 +105,24 @@ async def start( ) else: bootstrap_config = WorkerConfig() - wfs, acts, deps = discover(workflows, act_names=activities, deps_name=dependencies) + registered_wfs, registered_acts, registered_deps = discover( + workflows, act_names=activities, deps_name=dependencies + ) worker_id = create_worker_id(worker_id_prefix or "worker") + client = await TemporalClient.connect(temporal_address, namespace=namespace) event_loop = asyncio.get_event_loop() - deps_cm = ( - with_dependencies( - deps, - worker_config=bootstrap_config, - worker_id=worker_id, - event_loop=event_loop, - ) - if deps - else _do_nothing_cm - ) - async with deps_cm: - client = await TemporalClient.connect(temporal_address, namespace=namespace) - acts = [init_activity(a, client=client, event_loop=event_loop) for a in acts] - worker = datashare_worker( - client, - worker_id, - workflows=wfs, - activities=acts, - task_queue=queue, - max_concurrent_activities=max_concurrent_activities, - ) + async with bootstrap_worker( + worker_id, + activities=registered_acts, + workflows=registered_wfs, + dependencies=registered_deps, + bootstrap_config=bootstrap_config, + client=client, + event_loop=event_loop, + task_queue=queue, + ) as worker: try: await worker.run() except Exception as e: # noqa: BLE001 await worker.shutdown() raise e - - -@asynccontextmanager -async def _do_nothing_cm() -> AsyncGenerator[None, None]: - yield diff --git a/datashare-python/datashare_python/config.py b/datashare-python/datashare_python/config.py index 
e77fc5f..c48ce49 100644 --- a/datashare-python/datashare_python/config.py +++ b/datashare-python/datashare_python/config.py @@ -83,6 +83,8 @@ class WorkerConfig(ICIJSettings, LogWithWorkerIDMixin, BaseModel): elasticsearch: ESClientConfig = ESClientConfig() temporal: TemporalClientConfig = TemporalClientConfig() + max_concurrent_io_activities: int = 5 + def to_es_client(self) -> ESClient: return self.elasticsearch.to_es_client(self.datashare.api_key) diff --git a/datashare-python/datashare_python/worker.py b/datashare-python/datashare_python/worker.py index fc8e8c6..ad83590 100644 --- a/datashare-python/datashare_python/worker.py +++ b/datashare-python/datashare_python/worker.py @@ -2,18 +2,27 @@ import logging import os import socket +import sys import threading from asyncio import AbstractEventLoop -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from typing import Any from temporalio.worker import PollerBehaviorSimpleMaximum, Worker +from .config import WorkerConfig +from .dependencies import with_dependencies from .discovery import Activity -from .types_ import TemporalClient +from .types_ import ContextManagerFactory, TemporalClient logger = logging.getLogger(__name__) +_TEMPORAL_CLIENT = "temporal_client" +_EVENT_LOOP = "event_loop" +_EXPECTED_INIT_ARGS = {"self", _TEMPORAL_CLIENT, _EVENT_LOOP, "args", "kwargs"} + _SEPARATE_IO_AND_CPU_WORKERS = """The worker will run sync (CPU-bound) activities as \ well as IO-bound workflows. 
To avoid deadlocks due to the GIL, we advise to run all CPU-bound activities inside a \ @@ -43,7 +52,7 @@ def datashare_worker( task_queue: str, # Scale horizontally be default for activities, each worker processes one activity # at a time - max_concurrent_activities: int = 1, + max_concurrent_io_activities: int = 10, ) -> Worker: if workflows is None: workflows = [] @@ -62,8 +71,11 @@ def datashare_worker( ) logger.warning(_SEPARATE_IO_AND_CPU_ACTIVITIES) - if isinstance(activity_executor, ThreadPoolExecutor) and workflows: - logger.warning(_SEPARATE_IO_AND_CPU_WORKERS) + max_concurrent_activities = max_concurrent_io_activities + if isinstance(activity_executor, ThreadPoolExecutor): + max_concurrent_activities = 1 + if workflows: + logger.warning(_SEPARATE_IO_AND_CPU_WORKERS) return Worker( client, @@ -90,21 +102,11 @@ def create_worker_id(prefix: str) -> str: return f"{prefix}-{hostname}-{pid}-{threadid}" -_CLIENT = "client" -_WORKER = "worker" -_EXPECTED_INIT_ARGS = {"self", _CLIENT, _WORKER} - - -def _get_class_from_method(method: Callable) -> type: - qualname = method.__qualname__ - class_name = qualname.rsplit(".", 1)[0] - return method.__globals__.get(class_name) - - def init_activity( activity: Callable, client: TemporalClient, event_loop: AbstractEventLoop ) -> Callable: - if not inspect.ismethod(activity): + is_object_method = "." 
not in activity.__qualname__ + if is_object_method: return activity cls = _get_class_from_method(activity) init_args = inspect.signature(cls.__init__).parameters @@ -112,8 +114,60 @@ def init_activity( if invalid: msg = f"invalid activity arguments: {invalid}" raise ValueError(msg) - kwargs = {"client": client, "event_loop": event_loop} + kwargs = {_TEMPORAL_CLIENT: client, _EVENT_LOOP: event_loop} kwargs = {k: v for k, v in kwargs.items() if k in _EXPECTED_INIT_ARGS} if not kwargs: return activity - return cls(**kwargs) + act_instance = cls(**kwargs) + act_method = getattr(act_instance, activity.__name__) + return act_method + + +@asynccontextmanager +async def bootstrap_worker( + worker_id: str, + *, + activities: list[Callable[..., Any] | None] | None = None, + workflows: list[type] | None = None, + bootstrap_config: WorkerConfig, + client: TemporalClient, + event_loop: AbstractEventLoop, + task_queue: str, + dependencies: list[ContextManagerFactory] | None = None, +) -> AsyncGenerator[Worker, None]: + deps_cm = ( + with_dependencies( + dependencies, + worker_config=bootstrap_config, + worker_id=worker_id, + event_loop=event_loop, + ) + if dependencies + else _do_nothing_cm + ) + async with deps_cm: + if activities is not None: + acts = [ + init_activity(a, client=client, event_loop=event_loop) + for a in activities + ] + worker = datashare_worker( + client, + worker_id, + workflows=workflows, + activities=acts, + task_queue=task_queue, + max_concurrent_io_activities=bootstrap_config.max_concurrent_io_activities, + ) + yield worker + + +@asynccontextmanager +async def _do_nothing_cm() -> AsyncGenerator[None, None]: + yield + + +def _get_class_from_method(method: Callable) -> type: + class_name = method.__qualname__.rsplit(".", 1)[0] + module = sys.modules[method.__module__] + return getattr(module, class_name) diff --git a/worker-template/tests/conftest.py b/worker-template/tests/conftest.py index 1e5ff58..bc05a67 100644 --- a/worker-template/tests/conftest.py 
+++ b/worker-template/tests/conftest.py @@ -2,7 +2,6 @@ import uuid from asyncio import AbstractEventLoop from collections.abc import AsyncGenerator -from concurrent.futures import ThreadPoolExecutor from typing import Any import pytest @@ -33,10 +32,9 @@ with_dependencies, ) from datashare_python.types_ import ContextManagerFactory -from icij_common.es import ESClient +from datashare_python.worker import bootstrap_worker from temporalio.client import Client as TemporalClient from temporalio.testing import ActivityEnvironment -from temporalio.worker import Worker from worker_template.activities import ( ClassifyDocs, CreateClassificationBatches, @@ -81,36 +79,31 @@ async def lifetime_deps( @pytest.fixture(scope="session") async def io_worker( - test_es_client_session: ESClient, # noqa: F811 + test_worker_config: WorkerConfig, # noqa: F811 test_temporal_client_session: TemporalClient, # noqa: F811 event_loop: asyncio.AbstractEventLoop, # noqa: F811 + test_deps: list[ContextManagerFactory], # noqa: F811 ) -> AsyncGenerator[None, None]: - es_client = test_es_client_session - temporal_client = test_temporal_client_session + client = test_temporal_client_session worker_id = f"test-io-worker-{uuid.uuid4()}" - pong_activity = Pong(temporal_client=temporal_client, event_loop=event_loop) + pong_activity = Pong(temporal_client=client, event_loop=event_loop) io_activities = [ pong_activity.pong, - CreateTranslationBatches( - es_client=es_client, - temporal_client=temporal_client, - event_loop=event_loop, - ).create_translation_batches, - CreateClassificationBatches( - es_client=es_client, - temporal_client=temporal_client, - event_loop=event_loop, - ).create_classification_batches, + CreateTranslationBatches.create_translation_batches, + CreateClassificationBatches.create_classification_batches, ] workflows = [PingWorkflow, TranslateAndClassifyWorkflow] - worker = Worker( - temporal_client, - identity=worker_id, - task_queue=TaskQueues.CPU, + task_queue = TaskQueues.CPU + 
async with bootstrap_worker( + worker_id, activities=io_activities, workflows=workflows, - ) - async with worker: + bootstrap_config=test_worker_config, + client=client, + event_loop=event_loop, + task_queue=task_queue, + dependencies=test_deps, + ) as worker: t = None try: t = asyncio.create_task(worker.run()) @@ -122,70 +115,60 @@ async def io_worker( @pytest.fixture(scope="session") async def translation_worker( - test_es_client_session: ESClient, # noqa: F811 + test_worker_config: WorkerConfig, # noqa: F811 test_temporal_client_session: TemporalClient, # noqa: F811 event_loop: asyncio.AbstractEventLoop, # noqa: F811 + test_deps: list[ContextManagerFactory], # noqa: F811 ) -> AsyncGenerator[None, None]: - es_client = test_es_client_session - temporal_client = test_temporal_client_session + client = test_temporal_client_session worker_id = f"test-translation-worker-{uuid.uuid4()}" - translation_activities = [ - TranslateDocs( - es_client=es_client, - temporal_client=temporal_client, - event_loop=event_loop, - ).translate_docs, - ] - with ThreadPoolExecutor() as executor: - worker = Worker( - temporal_client, - identity=worker_id, - task_queue=TaskQueues.TRANSLATE_GPU, - activities=translation_activities, - activity_executor=executor, - ) - async with worker: - t = None - try: - t = asyncio.create_task(worker.run()) - yield - except Exception: # noqa: BLE001 - if t is not None: - t.cancel() + translation_activities = [TranslateDocs.translate_docs] + task_queue = TaskQueues.TRANSLATE_GPU + async with bootstrap_worker( + worker_id, + activities=translation_activities, + bootstrap_config=test_worker_config, + client=client, + event_loop=event_loop, + task_queue=task_queue, + dependencies=test_deps, + ) as worker: + t = None + try: + t = asyncio.create_task(worker.run()) + yield + except Exception: # noqa: BLE001 + if t is not None: + t.cancel() @pytest.fixture(scope="session") async def classification_worker( - test_es_client_session: ESClient, # noqa: F811 + 
test_worker_config: WorkerConfig, test_temporal_client_session: TemporalClient, # noqa: F811 event_loop: asyncio.AbstractEventLoop, # noqa: F811 + test_deps: list[ContextManagerFactory], # noqa: F811 ) -> AsyncGenerator[None, None]: - es_client = test_es_client_session - temporal_client = test_temporal_client_session + client = test_temporal_client_session worker_id = f"test-classification-worker-{uuid.uuid4()}" - classification_activities = [ - ClassifyDocs( - es_client=es_client, - temporal_client=temporal_client, - event_loop=event_loop, - ).classify_docs, - ] - with ThreadPoolExecutor() as executor: - worker = Worker( - temporal_client, - identity=worker_id, - task_queue=TaskQueues.CLASSIFY_GPU, - activities=classification_activities, - activity_executor=executor, - ) - async with worker: - t = None - try: - t = asyncio.create_task(worker.run()) - yield - except Exception: # noqa: BLE001 - if t is not None: - t.cancel() + classification_activities = [ClassifyDocs.classify_docs] + task_queue = TaskQueues.CLASSIFY_GPU + async with bootstrap_worker( + worker_id, + activities=classification_activities, + bootstrap_config=test_worker_config, + client=client, + event_loop=event_loop, + task_queue=task_queue, + dependencies=test_deps, + ) as worker: + t = None + try: + t = asyncio.create_task(worker.run()) + yield + except Exception: # noqa: BLE001 + if t is not None: + t.cancel() @pytest.fixture(scope="session") diff --git a/worker-template/worker_template/activities.py b/worker-template/worker_template/activities.py index cfa94bb..dc1cce8 100644 --- a/worker-template/worker_template/activities.py +++ b/worker-template/worker_template/activities.py @@ -40,7 +40,8 @@ must_not, ) from temporalio import activity -from temporalio.client import Client + +from worker_template.dependencies import lifespan_es_client if TYPE_CHECKING: from transformers import Pipeline @@ -55,60 +56,35 @@ async def pong(self) -> str: class CreateTranslationBatches(ActivityWithProgress): - def 
__init__( - self, - es_client: ESClient, - temporal_client: Client, - event_loop: asyncio.AbstractEventLoop, - ): - super().__init__(temporal_client, event_loop) - self._es_client = es_client - @activity_defn(name="create-translation-batches") async def create_translation_batches( self, project: str, target_language: str, batch_size: int ) -> list[list[str]]: + es_client = lifespan_es_client() return await create_translation_batches( project=project, target_language=target_language, batch_size=batch_size, - es_client=self._es_client, + es_client=es_client, ) class CreateClassificationBatches(ActivityWithProgress): - def __init__( - self, - es_client: ESClient, - temporal_client: Client, - event_loop: asyncio.AbstractEventLoop, - ): - super().__init__(temporal_client, event_loop) - self._es_client = es_client - @activity_defn(name="create-classification-batches") async def create_classification_batches( self, project: str, target_language: str, config: ClassificationConfig ) -> list[list[str]]: + es_client = lifespan_es_client() return await create_classification_batches( project=project, language=target_language, config=config, - es_client=self._es_client, + es_client=es_client, logger=activity.logger, ) class TranslateDocs(ActivityWithProgress): - def __init__( - self, - es_client: ESClient, - temporal_client: Client, - event_loop: asyncio.AbstractEventLoop, - ): - super().__init__(temporal_client, event_loop) - self._es_client = es_client - @activity_defn(name="translate-docs") def translate_docs( self, @@ -119,12 +95,13 @@ def translate_docs( config: TranslationConfig, progress: ProgressRateHandler | None = None, ) -> int: + es_client = lifespan_es_client() return self._event_loop.run_until_complete( translate_docs( docs, target_language=target_language, project=project, - es_client=self._es_client, + es_client=es_client, config=config, progress=progress, ) @@ -132,15 +109,6 @@ def translate_docs( class ClassifyDocs(ActivityWithProgress): - def __init__( - self, 
- es_client: ESClient, - temporal_client: Client, - event_loop: asyncio.AbstractEventLoop, - ): - super().__init__(temporal_client, event_loop) - self._es_client = es_client - @activity_defn(name="classify-docs") async def classify_docs( self, @@ -151,11 +119,12 @@ async def classify_docs( config: ClassificationConfig, progress: ProgressRateHandler | None = None, ) -> int: + es_client = lifespan_es_client() return await classify_docs( docs, classified_language=classified_language, project=project, - es_client=self._es_client, + es_client=es_client, config=config, progress=progress, ) diff --git a/worker-template/worker_template/dependencies.py b/worker-template/worker_template/dependencies.py index 64dd6d3..550a100 100644 --- a/worker-template/worker_template/dependencies.py +++ b/worker-template/worker_template/dependencies.py @@ -1,10 +1,9 @@ from datashare_python.dependencies import ( + lifespan_es_client, # noqa: F401 set_es_client, - set_event_loop, set_loggers, - set_temporal_client, ) -BASE = [set_loggers, set_event_loop, set_es_client, set_temporal_client] +BASE = [set_loggers, set_es_client] DEPENDENCIES = {"base": BASE} From 3e616a510a4c11cf5c2dd1fdb4e26350104aeee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 1 Apr 2026 11:57:03 +0200 Subject: [PATCH 5/5] doc: document dependency injection --- .../datashare_python/dependencies.py | 34 ++++- datashare-python/tests/test_dependencies.py | 33 +++++ docs/get-started/implement/worker-basic.md | 4 +- docs/guides/dependency-injection.md | 129 +++++++++++++++++- docs/src/asr_activity.py | 11 ++ docs/src/dependencies.py | 46 +++++++ docs/src/naive_asr_activity.py | 12 ++ docs/src/naive_dependencies.py | 12 ++ docs/src/pyproject.toml | 6 + 9 files changed, 277 insertions(+), 10 deletions(-) create mode 100644 datashare-python/tests/test_dependencies.py create mode 100644 docs/src/asr_activity.py create mode 100644 docs/src/dependencies.py create mode 100644 
docs/src/naive_asr_activity.py create mode 100644 docs/src/naive_dependencies.py diff --git a/datashare-python/datashare_python/dependencies.py b/datashare-python/datashare_python/dependencies.py index 8279a56..00ae9f4 100644 --- a/datashare-python/datashare_python/dependencies.py +++ b/datashare-python/datashare_python/dependencies.py @@ -1,8 +1,11 @@ +import inspect import logging from asyncio import AbstractEventLoop, iscoroutine -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager from contextvars import ContextVar +from copy import deepcopy +from typing import Any from icij_common.es import ESClient @@ -20,7 +23,7 @@ TEMPORAL_CLIENT: ContextVar[TemporalClient] = ContextVar("temporal_client") -def set_event_loop(event_loop: AbstractEventLoop, **_) -> None: +def set_event_loop(event_loop: AbstractEventLoop) -> None: EVENT_LOOP.set(event_loop) @@ -31,13 +34,13 @@ def lifespan_event_loop() -> AbstractEventLoop: raise DependencyInjectionError("event loop") from e -def set_loggers(worker_config: WorkerConfig, worker_id: str, **_) -> None: +def set_loggers(worker_config: WorkerConfig, worker_id: str) -> None: worker_config.setup_loggers(worker_id=worker_id) logger.info("worker loggers ready to log 💬") logger.info("app config: %s", worker_config.model_dump_json(indent=2)) -async def set_es_client(worker_config: WorkerConfig, **_) -> ESClient: +async def set_es_client(worker_config: WorkerConfig) -> ESClient: client = worker_config.to_es_client() ES_CLIENT.set(client) return client @@ -52,7 +55,7 @@ def lifespan_es_client() -> ESClient: # Task client setup -async def set_task_client(worker_config: WorkerConfig, **_) -> DatashareTaskClient: +async def set_task_client(worker_config: WorkerConfig) -> DatashareTaskClient: task_client = worker_config.to_task_client() TASK_CLIENT.set(task_client) return task_client @@ -67,7 +70,7 @@ def lifespan_task_client() -> 
DatashareTaskClient: # Temporal client setup -async def set_temporal_client(worker_config: WorkerConfig, **_) -> None: +async def set_temporal_client(worker_config: WorkerConfig) -> None: client = await worker_config.to_temporal_client() TEMPORAL_CLIENT.set(client) @@ -86,7 +89,7 @@ async def with_dependencies( ) -> AsyncGenerator[None, None]: async with AsyncExitStack() as stack: for dep in dependencies: - cm = dep(**kwargs) + cm = dep(**add_missing_args(dep, kwargs)) if hasattr(cm, "__aenter__"): await stack.enter_async_context(cm) elif hasattr(cm, "__enter__"): @@ -94,3 +97,20 @@ async def with_dependencies( elif iscoroutine(cm): await cm yield + + +def add_missing_args(fn: Callable, args: dict[str, Any], **kwargs) -> dict[str, Any]: + # We make the choice not to raise in case of missing argument here, the error will + # be correctly raise when the function is called + from_kwargs = dict() + sig = inspect.signature(fn) + for param_name in sig.parameters: + if param_name in args: + continue + kwargs_value = kwargs.get(param_name) + if kwargs_value is not None: + from_kwargs[param_name] = kwargs_value + if from_kwargs: + args = deepcopy(args) + args.update(from_kwargs) + return args diff --git a/datashare-python/tests/test_dependencies.py b/datashare-python/tests/test_dependencies.py new file mode 100644 index 0000000..e8d1693 --- /dev/null +++ b/datashare-python/tests/test_dependencies.py @@ -0,0 +1,33 @@ +from typing import Any + +import pytest +from datashare_python.dependencies import add_missing_args + + +@pytest.mark.parametrize( + ("provided_args", "kwargs", "maybe_output"), + [ + ({}, {}, None), + ({"a": "a"}, {}, None), + ({"a": "a"}, {"b": "b"}, "a-b-c"), + ({"a": "a", "b": "b"}, {"c": "not-your-average-c"}, "a-b-not-your-average-c"), + ], +) +def test_add_missing_args( + provided_args: dict[str, Any], + kwargs: dict[str, Any], + maybe_output: str | None, +) -> None: + # Given + def fn(a: str, b: str, c: str = "c") -> str: + return f"{a}-{b}-{c}" + + # 
When + all_args = add_missing_args(fn, args=provided_args, **kwargs) + # Then + if maybe_output is not None: + output = fn(**all_args) + assert output == maybe_output + else: + with pytest.raises(TypeError): + fn(**all_args) diff --git a/docs/get-started/implement/worker-basic.md b/docs/get-started/implement/worker-basic.md index b8a7829..e7bf26e 100644 --- a/docs/get-started/implement/worker-basic.md +++ b/docs/get-started/implement/worker-basic.md @@ -35,7 +35,7 @@ basic_activities.py:activities 1. decorate the activity function with `@activity_defn` using `"hello-user"` as name 2. expose the activity function to `datashare-python`'s CLI, by listing it in the `ACTIVITIES` variable -Under the hood, the `ACTIVITIES` is registered as +Under the hood, the `ACTIVITIES` variable is registered as [plugin entrypoint](https://setuptools.pypa.io/en/latest/userguide/entry_point.html#entry-points-for-plugins) in the package's `pyproject.toml`: ```toml title="pyproject.toml" @@ -46,7 +46,7 @@ pyproject.toml:entry_points_acts When running a worker using `datashare-python worker start` CLI, `datashare-python` will look for any variable registered under the `:::toml "datashare.activities"` key and will be able to run activities registered in these variables. -You can register as many variables as you want, under the names of your choices, as long as it's registered under +You can register as many variables as you want, under the names of your choices, as long as it's bound under the `:::toml "datashare.activities"` key. ## Implement and register a workflow diff --git a/docs/guides/dependency-injection.md b/docs/guides/dependency-injection.md index ff33901..df4e996 100644 --- a/docs/guides/dependency-injection.md +++ b/docs/guides/dependency-injection.md @@ -1 +1,128 @@ -# Dependency injection (WIP) \ No newline at end of file +# Dependency injection + +## What are dependencies ? 
+
+We call "lifespan dependency" or just "dependency" any object used by an activity worker and which needs to live longer
+than the activity duration.
+
+Common dependencies include:
+- clients
+- connections and connection pools
+- ML models
+
+Dependency injection makes sure that these dependencies are **available** in your activity code and also makes sure
+**resources are correctly freed** when they are no longer needed.
+
+## How `datashare-python` dependency injection works?
+
+`datashare-python` dependency injection is inspired by [fastapi](https://fastapi.tiangolo.com/advanced/events/)'s
+lifespan events handling.
+
+The idea is to:
+1. provide a number of functions and/or context managers initializing dependencies and storing them in the worker
+thread [context](https://docs.python.org/3/library/contextvars.html)
+2. define functions accessing this context, letting you access these variables in your code
+
+## Providing dependencies
+
+If you are building an automatic speech recognition worker, you might implement the following activity:
+```python title="activities.py"
+--8<--
+naive_asr_activity.py
+--8<--
+```
+
+1. this is awfully heavy, we don't want to reload the model every time!
+
+The obvious problem with this implementation is that we'll reload the model each time we receive audios to process.
+Ideally we'd like to have the model **preloaded in memory and just run inference**.
+
+Instead, we'll define a lifespan dependency which loads the model and stores it into the worker thread context
+variables:
+
+```python title="dependency.py"
+--8<--
+naive_dependencies.py
+--8<--
+```
+
+1. register a context variable with the `ml_model` name
+2. load the model
+3. store the model into the registered context variable
+
+A **better version** of this dependency uses a context manager to **make sure resources are freed** when the worker no longer
+needs the dependency:
+
+```python title="dependency.py"
+--8<--
+dependencies.py:context_manager
+--8<--
+```
+
+1. let the calling code run
+2. clean everything up when the caller is done
+
+## Accessing dependencies
+
+Now that we've registered our dependency in the thread context, we need to update our activity to access the context
+variable. We can do it directly by calling `:::python ContextVar("ml_model").get()`, but we can more elegantly define the
+following dependency function:
+
+```python title="dependency.py"
+--8<--
+dependencies.py:expose_dependency
+--8<--
+```
+
+Next, we'll use this function in our activity:
+```python title="dependency.py"
+--8<--
+asr_activity.py
+--8<--
+```
+
+1. load the cached model rather than reloading it
+
+## Worker dependency discovery
+
+In order for dependencies to be discoverable by `datashare-python`'s CLI, they need to be registered.
+
+```python title="dependency.py"
+--8<--
+dependencies.py:registry
+--8<--
+```
+
+Under the hood, the `DEPENDENCIES` variable is registered as a [plugin entrypoint](https://setuptools.pypa.io/en/latest/userguide/entry_point.html#entry-points-for-plugins)
+
+```toml title="pyproject.toml"
+--8<--
+pyproject.toml:entry_points_deps
+--8<--
+```
+
+When running a worker using `datashare-python worker start` CLI, `datashare-python` will look for any variable registered under
+the `:::toml "datashare.dependencies"` key and the `dependencies` entry point name and will be able to use the dependencies registered in these variables.
+
+You can register as many dependency sets as you want in the bound variable. You can use the variable name of your choice
+for the dict registry, as long as it's bound under the `:::toml "datashare.dependencies"` key and the `:::toml dependencies` entry point name.
+
+
+## Selecting dependencies when running `datashare-python`'s CLI
+
+When running an activity worker using
+
+
+```console
+datashare-python worker start --activities asr-transcription
+```
+
+`datashare-python` will auto-discover dependencies and if the registry has a single entry in it, it will
+automatically use this dependency set.
+
+In case your registry contains multiple dependency sets, you can call the CLI providing the set's key (here `:::python "base"`) as an argument:
+
+
+```console
+datashare-python worker start --activities asr-transcription --dependencies base
+```
diff --git a/docs/src/asr_activity.py b/docs/src/asr_activity.py
new file mode 100644
index 0000000..a246fc0
--- /dev/null
+++ b/docs/src/asr_activity.py
@@ -0,0 +1,11 @@
+from pathlib import Path
+
+from utils import activity_defn
+
+from .dependencies import lifespan_ml_model
+
+
+@activity_defn(name="asr-transcription")
+def asr_activity(audios: list[Path]) -> list:
+    ml_model = lifespan_ml_model()  # (1)!
+    return ml_model.transcribe(audios)
diff --git a/docs/src/dependencies.py b/docs/src/dependencies.py
new file mode 100644
index 0000000..b9e74a5
--- /dev/null
+++ b/docs/src/dependencies.py
@@ -0,0 +1,46 @@
+import gc
+from collections.abc import Generator
+from contextlib import contextmanager
+from contextvars import ContextVar
+
+import torch
+from datashare_python.exceptions import DependencyInjectionError
+from transformers import CohereAsrForConditionalGeneration
+
+ML_MODEL: ContextVar[dict | None] = ContextVar("ml_model")
+
+
+# --8<-- [start:context_manager]
+@contextmanager
+def load_ml_model() -> Generator[None, None, None]:
+    ml_model = CohereAsrForConditionalGeneration.from_pretrained(
+        "CohereLabs/cohere-transcribe-03-2026", device_map="auto"
+    )
+    ML_MODEL.set(ml_model)
+    try:
+        yield  # (1)!
+    finally:  # (2)!
+ del ml_model + torch.cuda.empty_cache() + gc.collect() + ML_MODEL.set(None) + + +# --8<-- [end:context_manager] + + +# --8<-- [start:expose_dependency] +def lifespan_ml_model() -> CohereAsrForConditionalGeneration: + try: + return ML_MODEL.get() + except LookupError as e: + raise DependencyInjectionError("ml model") from e + + +# --8<-- [end:expose_dependency] + +# --8<-- [start:registry] +DEPENDENCIES = { + "base": [lifespan_ml_model], +} +# --8<-- [end:registry] diff --git a/docs/src/naive_asr_activity.py b/docs/src/naive_asr_activity.py new file mode 100644 index 0000000..9ed8ea0 --- /dev/null +++ b/docs/src/naive_asr_activity.py @@ -0,0 +1,12 @@ +from pathlib import Path + +from transformers import CohereAsrForConditionalGeneration +from utils import activity_defn + + +@activity_defn(name="asr-transcription") +def asr_activity(audios: list[Path]) -> list: + ml_model = CohereAsrForConditionalGeneration.from_pretrained( # (1)! + "CohereLabs/cohere-transcribe-03-2026", device_map="auto" + ) + return ml_model.transcribe(audios) diff --git a/docs/src/naive_dependencies.py b/docs/src/naive_dependencies.py new file mode 100644 index 0000000..a8c8556 --- /dev/null +++ b/docs/src/naive_dependencies.py @@ -0,0 +1,12 @@ +from contextvars import ContextVar + +from transformers import CohereAsrForConditionalGeneration + +ML_MODEL: ContextVar[dict | None] = ContextVar("ml_model") # (1)! + + +def load_ml_model() -> None: + ml_model = CohereAsrForConditionalGeneration.from_pretrained( # (2)! + "CohereLabs/cohere-transcribe-03-2026", device_map="auto" + ) + ML_MODEL.set(ml_model) # (3)! 
diff --git a/docs/src/pyproject.toml b/docs/src/pyproject.toml index 37f603f..5fed172 100644 --- a/docs/src/pyproject.toml +++ b/docs/src/pyproject.toml @@ -8,4 +8,10 @@ workflows = "basic_worker.workflows:WORKFLOWS" [project.entry-points."datashare.activities"] activities = "basic_worker.activities:ACTIVITIES" # --8<-- [end:entry_points_acts] + +# --8<-- [start:entry_points_deps] +[project.entry-points."datashare.dependencies"] +dependencies = "asr_worker.dependencies:DEPENDENCIES" +# --8<-- [end:entry_points_deps] + # --8<-- [start:entry_points]