diff --git a/datashare-python/datashare_python/cli/worker.py b/datashare-python/datashare_python/cli/worker.py index 4e6ac42..eaffa89 100644 --- a/datashare-python/datashare_python/cli/worker.py +++ b/datashare-python/datashare_python/cli/worker.py @@ -1,12 +1,19 @@ +import asyncio import logging +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from pathlib import Path from typing import Annotated import typer +import yaml +from datashare_python.config import WorkerConfig from datashare_python.constants import DEFAULT_NAMESPACE, DEFAULT_TEMPORAL_ADDRESS -from datashare_python.discovery import discover_activities, discover_workflows +from datashare_python.dependencies import with_dependencies +from datashare_python.discovery import discover, discover_activities, discover_workflows from datashare_python.types_ import TemporalClient -from datashare_python.worker import datashare_worker +from datashare_python.worker import create_worker_id, datashare_worker, init_activity from .utils import AsyncTyper @@ -20,6 +27,12 @@ _START_WORKER_WORKFLOWS_HELP = "workflow names run by the worker (supports regexes)" _START_WORKER_ACTIVITIES_HELP = "activity names run by the worker (supports regexes)" +_START_WORKER_DEPS_HELP = "worker lifetime dependencies name in the registry" +_START_WORKER_WORKER_ID_PREFIX_HELP = "worker ID prefix" +_START_WORKER_CONFIG_PATH_HELP = ( + "path to a worker config YAML file," + " if not provided will load worker configuration from env variables" +) _WORKER_QUEUE_HELP = "worker task queue" _WORKER_MAX_ACTIVITIES_HELP = ( "maximum number of concurrent activities/tasks" @@ -73,6 +86,18 @@ async def start( workflows: Annotated[list[str], typer.Option(help=_START_WORKER_WORKFLOWS_HELP)], activities: Annotated[list[str], typer.Option(help=_START_WORKER_ACTIVITIES_HELP)], queue: Annotated[str, typer.Option("--queue", "-q", help=_WORKER_QUEUE_HELP)], + dependencies: Annotated[ + str | None, typer.Option(help=_START_WORKER_DEPS_HELP) + ] = None, + config_path: Annotated[ + Path | None, + typer.Option( + "--config-path", "--config", "-c", help=_START_WORKER_CONFIG_PATH_HELP + ), + ] = None, + worker_id_prefix: Annotated[ + str | None, typer.Option(help=_START_WORKER_WORKER_ID_PREFIX_HELP) + ] = None, temporal_address: Annotated[ str, typer.Option("--temporal-address", "-a", help=_TEMPORAL_URL_HELP) ] = DEFAULT_TEMPORAL_ADDRESS, @@ -83,32 +108,44 @@ async def start( int, typer.Option("--max-activities", help=_WORKER_MAX_ACTIVITIES_HELP) ] = 1, ) -> None: - wf_names, wfs = zip(*discover_workflows(workflows), strict=False) - registered = "" - if wf_names: - n_wfs = len(wf_names) - registered += ( - f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {','.join(wf_names)}" + if config_path is not None: + with config_path.open() as f: + bootstrap_config = WorkerConfig.model_validate( + yaml.load(f, Loader=yaml.Loader) + ) + else: + bootstrap_config = WorkerConfig() + wfs, acts, deps = discover(workflows, act_names=activities, deps_name=dependencies) + worker_id = create_worker_id(worker_id_prefix or "worker") + event_loop = asyncio.get_event_loop() + deps_cm = ( + with_dependencies( + deps, + worker_config=bootstrap_config, + worker_id=worker_id, + event_loop=event_loop, ) - act_names, acts = zip(*discover_activities(activities), strict=False) - if act_names: - if registered: - registered += "\n" - i = len(act_names) - registered += f"- {i} activit{'ies' if i > 1 else 'y'}: {','.join(act_names)}" - if not acts and not wfs: - raise ValueError("Couldn't find any registered activity or workflow.") - logger.info("Starting datashare worker running:\n%s", registered) - client = await TemporalClient.connect(temporal_address, namespace=namespace) - worker = datashare_worker( - client, - workflows=wfs, - activities=acts, - task_queue=queue, - max_concurrent_activities=max_concurrent_activities, + if deps + else _do_nothing_cm ) - try: - await worker.run() - except Exception as e: # noqa: BLE001 - await worker.shutdown() - raise e + async with deps_cm: + client = await TemporalClient.connect(temporal_address, namespace=namespace) + acts = [init_activity(a, client=client, event_loop=event_loop) for a in acts] + worker = datashare_worker( + client, + worker_id, + workflows=wfs, + activities=acts, + task_queue=queue, + max_concurrent_activities=max_concurrent_activities, + ) + try: + await worker.run() + except Exception as e: # noqa: BLE001 + await worker.shutdown() + raise e + + +@asynccontextmanager +async def _do_nothing_cm() -> AsyncGenerator[None, None]: + yield diff --git a/datashare-python/datashare_python/conftest.py b/datashare-python/datashare_python/conftest.py index 31fadb1..a96eee7 100644 --- a/datashare-python/datashare_python/conftest.py +++ b/datashare-python/datashare_python/conftest.py @@ -1,6 +1,7 @@ import asyncio from asyncio import AbstractEventLoop from collections.abc import AsyncGenerator, Generator, Iterator, Sequence +from pathlib import Path import aiohttp import pytest @@ -102,6 +103,13 @@ def test_worker_config() -> WorkerConfig: ) +@pytest.fixture +def test_worker_config_path(test_worker_config: WorkerConfig, tmpdir: Path) -> Path: + config_path = Path(tmpdir) / "config.json" + config_path.write_text(test_worker_config.model_dump_json()) + return config_path + + @pytest.fixture(scope="session") async def worker_lifetime_deps( event_loop: AbstractEventLoop, diff --git a/datashare-python/datashare_python/discovery.py b/datashare-python/datashare_python/discovery.py index 659dce8..be00611 100644 --- a/datashare-python/datashare_python/discovery.py +++ b/datashare-python/datashare_python/discovery.py @@ -1,16 +1,69 @@ +import logging import re from collections.abc import Callable, Iterable from importlib.metadata import entry_points +from .types_ import ContextManagerFactory from .utils import ActivityWithProgress +logger = logging.getLogger(__name__) + Activity = ActivityWithProgress | Callable | type +_DEPENDENCIES = "dependencies" _WORKFLOW_GROUPS = "datashare.workflows" _ACTIVITIES_GROUPS = "datashare.activities" +_DEPENDENCIES_GROUPS = "datashare.dependencies" + +_RegisteredWorkflow = tuple[str, type] +_RegisteredActivity = tuple[str, Activity] +_Dependencies = list[ContextManagerFactory] +_Discovery = tuple[ + Iterable[_RegisteredWorkflow] | None, + Iterable[_RegisteredActivity] | None, + _Dependencies | None, +] -def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]: +def discover( + wf_names: list[str] | None, *, act_names: list[str] | None, deps_name: str | None +) -> _Discovery: + discovered = "" + wfs = None + if wf_names is not None: + wf_names, wfs = zip(*discover_workflows(wf_names), strict=True) + if wf_names: + n_wfs = len(wf_names) + discovered += ( + f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {', '.join(wf_names)}" + ) + acts = None + if act_names is not None: + act_names, acts = zip(*discover_activities(act_names), strict=True) + if act_names: + if discovered: + discovered += "\n" + n_acts = len(act_names) + discovered += ( + f"- {n_acts} activit{'ies' if n_acts > 1 else 'y'}:" + f" {', '.join(act_names)}" + ) + if not acts and not wfs: + raise ValueError("Couldn't find any registered activity or workflow.") + deps = discover_dependencies(deps_name) + if deps: + n_deps = len(deps) + discovered += "\n" + deps_names = (d.__name__ for d in deps) + discovered += ( + f"- {n_deps} dependenc{'ies' if n_deps > 1 else 'y'}:" + f" {', '.join(deps_names)}" + ) + logger.info("discovered:\n%s", discovered) + return wfs, acts, deps + + +def discover_workflows(names: list[str]) -> Iterable[_RegisteredWorkflow]: pattern = None if not names else re.compile(rf"^{'|'.join(names)}$") impls = entry_points(group=_WORKFLOW_GROUPS) for wf_impls in impls: @@ -24,7 +77,7 @@ def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]: yield wf_name, wf_impl -def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]: +def discover_activities(names: list[str]) -> Iterable[_RegisteredActivity]: pattern = None if not names else re.compile(rf"^{'|'.join(names)}$") impls = entry_points(group=_ACTIVITIES_GROUPS) for act_impls in impls: @@ -38,6 +91,43 @@ def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]: yield act_name, act_impl +def discover_dependencies(name: str | None) -> _Dependencies | None: + impls = entry_points(name=_DEPENDENCIES, group=_DEPENDENCIES_GROUPS) + if not impls: + if name is None: + return None + available_impls = entry_points(group=_DEPENDENCIES_GROUPS) + msg = ( + f'failed to find dependency: "{name}", ' + f"available dependencies: {available_impls}" + ) + raise LookupError(msg) + if len(impls) > 1: + msg = f'found multiple dependencies for name "{name}": {impls}' + raise ValueError(msg) + deps_registry = impls[_DEPENDENCIES].load() + if name: + try: + return deps_registry[name] + except KeyError as e: + available = list(deps_registry) + msg = ( + f'failed to find dependency for name "{name}", available dependencies: ' + f"{available}" + ) + raise LookupError(msg) from e + if not deps_registry: + raise ValueError("empty dependency registry !") + if len(deps_registry) > 1: + available = ", ".join('"' + d + '"' for d in deps_registry) + msg = ( + f"dependency registry contains multiples entries {available}," + f" please select one by providing a name" + ) + raise ValueError(msg) + return next(iter(deps_registry.values())) + + def _parse_wf_name(wf_type: type) -> str: if not isinstance(wf_type, type): msg = ( diff --git a/datashare-python/datashare_python/worker.py b/datashare-python/datashare_python/worker.py index 9a521a8..fc8e8c6 100644 --- a/datashare-python/datashare_python/worker.py +++ b/datashare-python/datashare_python/worker.py @@ -1,4 +1,10 @@ +import inspect import logging +import os +import socket +import threading +from asyncio import AbstractEventLoop +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from temporalio.worker import PollerBehaviorSimpleMaximum, Worker @@ -30,6 +36,7 @@ def datashare_worker( client: TemporalClient, + worker_id: str, *, workflows: list[type] | None = None, activities: list[Activity] | None = None, @@ -60,6 +67,7 @@ def datashare_worker( return Worker( client, + identity=worker_id, workflows=workflows, activities=activities, task_queue=task_queue, @@ -72,3 +80,40 @@ def datashare_worker( # several of them workflow_task_poller_behavior=PollerBehaviorSimpleMaximum(5), ) + + +def create_worker_id(prefix: str) -> str: + pid = os.getpid() + threadid = threading.get_ident() + hostname = socket.gethostname() + # TODO: this might not be unique when using asyncio + return f"{prefix}-{hostname}-{pid}-{threadid}" + + +_CLIENT = "client" +_WORKER = "worker" +_EXPECTED_INIT_ARGS = {"self", _CLIENT, _WORKER} + + +def _get_class_from_method(method: Callable) -> type: + qualname = method.__qualname__ + class_name = qualname.rsplit(".", 1)[0] + return method.__globals__.get(class_name) + + +def init_activity( + activity: Callable, client: TemporalClient, event_loop: AbstractEventLoop +) -> Callable: + if not inspect.ismethod(activity): + return activity + cls = _get_class_from_method(activity) + init_args = inspect.signature(cls.__init__).parameters + invalid = [p for p in init_args if p not in _EXPECTED_INIT_ARGS] + if invalid: + msg = f"invalid activity arguments: {invalid}" + raise ValueError(msg) + kwargs = {"client": client, "event_loop": event_loop} + kwargs = {k: v for k, v in kwargs.items() if k in _EXPECTED_INIT_ARGS} + if not kwargs: + return activity + return cls(**kwargs) diff --git a/datashare-python/pyproject.toml b/datashare-python/pyproject.toml index 5e8285a..e66ade7 100644 --- a/datashare-python/pyproject.toml +++ b/datashare-python/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "datashare-worker-template~=0.1", "tomlkit~=0.14.0", "hatchling~=1.27.0", + "pyyaml~=6.0", ] [project.urls] diff --git a/datashare-python/tests/cli/test_worker.py b/datashare-python/tests/cli/test_worker.py index 9d195d8..b426cd9 100644 --- a/datashare-python/tests/cli/test_worker.py +++ b/datashare-python/tests/cli/test_worker.py @@ -1,3 +1,5 @@ +from pathlib import Path + from _pytest.capture import CaptureFixture from _pytest.monkeypatch import MonkeyPatch from datashare_python.cli import cli_app @@ -10,12 +12,13 @@ async def _mock_worker_run(self) -> None: # noqa: ANN001 async def test_start_workers( - worker_lifetime_deps, # noqa: ANN001, ARG001 typer_asyncio_patch, # noqa: ANN001, ARG001 + test_worker_config_path: Path, monkeypatch: MonkeyPatch, capsys: CaptureFixture[str], ) -> None: # Given + config_path = test_worker_config_path runner = CliRunner(mix_stderr=False) monkeypatch.setattr(Worker, "run", _mock_worker_run) with capsys.disabled(): @@ -27,6 +30,8 @@ async def test_start_workers( "start", "--queue", "cpu", + "-c", + str(config_path), "--activities", "ping", "--activities", @@ -40,7 +45,8 @@ async def test_start_workers( ) # Then assert result.exit_code == 0 - expected = """Starting datashare worker running: + expected = """discovered: - 1 workflow: ping -- 1 activity: create-translation-batches""" +- 1 activity: create-translation-batches +- 4 dependencies: set_loggers, set_event_loop, set_es_client, set_temporal_client""" assert expected in result.stderr diff --git a/datashare-python/tests/test_discovery.py b/datashare-python/tests/test_discovery.py index 47132b5..871a0fc 100644 --- a/datashare-python/tests/test_discovery.py +++ b/datashare-python/tests/test_discovery.py @@ -1,5 +1,15 @@ +import re +from importlib.metadata import EntryPoints +from unittest.mock import MagicMock + +import datashare_python import pytest -from datashare_python.discovery import discover_activities, discover_workflows +from _pytest.monkeypatch import MonkeyPatch +from datashare_python.discovery import ( + discover_activities, + discover_dependencies, + discover_workflows, +) @pytest.mark.parametrize( @@ -42,3 +52,66 @@ def test_discover_activities(names: list[str], expected_activities: set[str]) -> activities = {act_name for act_name, _ in discover_activities(names)} # Then assert activities == expected_activities + + +@pytest.mark.parametrize("name", ["base", None]) +def test_discover_dependencies(name: str | None) -> None: + # When + deps = discover_dependencies(name) + # Then + expected_deps = [ + "set_loggers", + "set_event_loop", + "set_es_client", + "set_temporal_client", + ] + assert [d.__name__ for d in deps] == expected_deps + + +def test_discover_dependencies_should_raise_for_unknown_dep() -> None: + # Given + unknown_dep = "unknown_dep" + # When/Then + expected = ( + 'failed to find dependency for name "unknown_dep", ' + "available dependencies: ['base']" + ) + with pytest.raises(LookupError, match=re.escape(expected)): + discover_dependencies(unknown_dep) + + +def test_discover_dependencies_should_raise_for_conflicting_deps( + monkeypatch: MonkeyPatch, +) -> None: + # Given + def mocked_entry_points(name: str, group: str) -> EntryPoints: # noqa: ARG001 + entry_points = MagicMock() + entry_points.__len__.return_value = 2 + return entry_points + + monkeypatch.setattr(datashare_python.discovery, "entry_points", mocked_entry_points) + # When/Then + expected = "found multiple dependencies for name" + with pytest.raises(ValueError, match=re.escape(expected)): + discover_dependencies(name=None) + + +def test_discover_dependencies_should_raise_for_multiple_entry_points( + monkeypatch: MonkeyPatch, +) -> None: + # Given + def mocked_entry_points(name: str, group: str) -> EntryPoints: # noqa: ARG001 + ep = MagicMock() + ep.load.return_value = {"a": [], "b": []} + entry_points = MagicMock() + entry_points.__getitem__.return_value = ep + return entry_points + + monkeypatch.setattr(datashare_python.discovery, "entry_points", mocked_entry_points) + # When/Then + expected = ( + 'dependency registry contains multiples entries "a", "b",' + " please select one by providing a name" + ) + with pytest.raises(ValueError, match=re.escape(expected)): + discover_dependencies(name=None) diff --git a/datashare-python/uv.lock b/datashare-python/uv.lock index 5610748..184c0d9 100644 --- a/datashare-python/uv.lock +++ b/datashare-python/uv.lock @@ -304,7 +304,7 @@ wheels = [ [[package]] name = "datashare-python" -version = "0.2.12" +version = "0.2.23" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -315,6 +315,7 @@ dependencies = [ { name = "icij-common", extra = ["elasticsearch"] }, { name = "nest-asyncio" }, { name = "python-json-logger" }, + { name = "pyyaml" }, { name = "temporalio" }, { name = "tomlkit" }, { name = "typer" }, @@ -344,6 +345,7 @@ requires-dist = [ { name = "icij-common", extras = ["elasticsearch"], specifier = "~=0.7.3" }, { name = "nest-asyncio", specifier = "~=1.6.0" }, { name = "python-json-logger", specifier = "~=4.0.0" }, + { name = "pyyaml", specifier = "~=6.0" }, { name = "temporalio", specifier = "~=1.23.0" }, { name = "tomlkit", specifier = "~=0.14.0" }, { name = "typer", specifier = "~=0.15.4" }, @@ -365,7 +367,7 @@ dev = [ [[package]] name = "datashare-worker-template" -version = "0.1.2" +version = "0.1.5" source = { editable = "../worker-template" } dependencies = [ { name = "datashare-python" }, diff --git a/worker-template/pyproject.toml b/worker-template/pyproject.toml index 2264de4..05b9bf3 100644 --- a/worker-template/pyproject.toml +++ b/worker-template/pyproject.toml @@ -28,6 +28,9 @@ workflows = "worker_template.workflows:WORKFLOWS" [project.entry-points."datashare.activities"] activities = "worker_template.activities:ACTIVITIES" +[project.entry-points."datashare.dependencies"] +dependencies = "worker_template.dependencies:DEPENDENCIES" + [tool.uv.sources] torch = [ { index = "pytorch-cpu" }, diff --git a/worker-template/worker_template/activities.py b/worker-template/worker_template/activities.py index c773cdd..01a12e3 100644 --- a/worker-template/worker_template/activities.py +++ b/worker-template/worker_template/activities.py @@ -2,6 +2,7 @@ import logging from collections.abc import AsyncGenerator, Generator, Iterable from functools import partial +from typing import TYPE_CHECKING from aiostream.stream import chain from datashare_python.objects import Document @@ -36,7 +37,9 @@ ) from temporalio import activity from temporalio.client import Client -from transformers import Pipeline, pipeline + +if TYPE_CHECKING: + from transformers import Pipeline from .objects_ import ClassificationConfig, TranslationConfig from .utils_ import async_batches, batches, before_and_after, once @@ -233,6 +236,7 @@ async def translate_docs( config: TranslationConfig | None = None, ) -> int: import torch # noqa:PLC0415 + from transformers import pipeline # noqa: PLC0415 if config is None: config = TranslationConfig() @@ -298,6 +302,7 @@ async def classify_docs( es_client: ESClient, ) -> int: import torch # noqa: PLC0415 + from transformers import pipeline # noqa: PLC0415 if config is None: config = ClassificationConfig() @@ -354,11 +359,11 @@ async def classify_docs( return n_docs -def _translate_as_list(pipe: Pipeline, texts: list[str]) -> list[str]: +def _translate_as_list(pipe: "Pipeline", texts: list[str]) -> list[str]: return list(_translate(pipe, texts)) -def _translate(pipe: Pipeline, texts: list[str]) -> Generator[str, None, None]: +def _translate(pipe: "Pipeline", texts: list[str]) -> Generator[str, None, None]: for res in pipe(texts): yield res["translation_text"] @@ -452,13 +457,13 @@ async def _count_untranslated( return res[COUNT] -def _classify(pipe: Pipeline, texts: list[str]) -> Generator[str, None, None]: +def _classify(pipe: "Pipeline", texts: list[str]) -> Generator[str, None, None]: # In practice, we should chunk the text for res in pipe(texts, padding=True, truncation=True): yield res["label"] -def _classify_as_list(pipe: Pipeline, texts: list[str]) -> list[str]: +def _classify_as_list(pipe: "Pipeline", texts: list[str]) -> list[str]: return list(_classify(pipe, texts)) diff --git a/worker-template/worker_template/dependencies.py b/worker-template/worker_template/dependencies.py new file mode 100644 index 0000000..64dd6d3 --- /dev/null +++ b/worker-template/worker_template/dependencies.py @@ -0,0 +1,10 @@ +from datashare_python.dependencies import ( + set_es_client, + set_event_loop, + set_loggers, + set_temporal_client, +) + +BASE = [set_loggers, set_event_loop, set_es_client, set_temporal_client] + +DEPENDENCIES = {"base": BASE}