Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 66 additions & 29 deletions datashare-python/datashare_python/cli/worker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import asyncio
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Annotated

import typer
import yaml

from datashare_python.config import WorkerConfig
from datashare_python.constants import DEFAULT_NAMESPACE, DEFAULT_TEMPORAL_ADDRESS
from datashare_python.discovery import discover_activities, discover_workflows
from datashare_python.dependencies import with_dependencies
from datashare_python.discovery import discover, discover_activities, discover_workflows
from datashare_python.types_ import TemporalClient
from datashare_python.worker import datashare_worker
from datashare_python.worker import create_worker_id, datashare_worker, init_activity

from .utils import AsyncTyper

Expand All @@ -20,6 +27,12 @@

_START_WORKER_WORKFLOWS_HELP = "workflow names run by the worker (supports regexes)"
_START_WORKER_ACTIVITIES_HELP = "activity names run by the worker (supports regexes)"
_START_WORKER_DEPS_HELP = "worker lifetime dependencies name in the registry"
_START_WORKER_WORKER_ID_PREFIX_HELP = "worker ID prefix"
_START_WORKER_CONFIG_PATH_HELP = (
"path to a worker config YAML file,"
" if not provided will load worker configuration from env variables"
)
_WORKER_QUEUE_HELP = "worker task queue"
_WORKER_MAX_ACTIVITIES_HELP = (
"maximum number of concurrent activities/tasks"
Expand Down Expand Up @@ -73,6 +86,18 @@ async def start(
workflows: Annotated[list[str], typer.Option(help=_START_WORKER_WORKFLOWS_HELP)],
activities: Annotated[list[str], typer.Option(help=_START_WORKER_ACTIVITIES_HELP)],
queue: Annotated[str, typer.Option("--queue", "-q", help=_WORKER_QUEUE_HELP)],
dependencies: Annotated[
str | None, typer.Option(help=_START_WORKER_DEPS_HELP)
] = None,
config_path: Annotated[
Path | None,
typer.Option(
"--config-path", "--config", "-c", help=_START_WORKER_CONFIG_PATH_HELP
),
] = None,
worker_id_prefix: Annotated[
str | None, typer.Option(help=_START_WORKER_WORKER_ID_PREFIX_HELP)
] = None,
temporal_address: Annotated[
str, typer.Option("--temporal-address", "-a", help=_TEMPORAL_URL_HELP)
] = DEFAULT_TEMPORAL_ADDRESS,
Expand All @@ -83,32 +108,44 @@ async def start(
int, typer.Option("--max-activities", help=_WORKER_MAX_ACTIVITIES_HELP)
] = 1,
) -> None:
wf_names, wfs = zip(*discover_workflows(workflows), strict=False)
registered = ""
if wf_names:
n_wfs = len(wf_names)
registered += (
f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {','.join(wf_names)}"
if config_path is not None:
with config_path.open() as f:
bootstrap_config = WorkerConfig.model_validate(
yaml.load(f, Loader=yaml.Loader)
)
else:
bootstrap_config = WorkerConfig()
wfs, acts, deps = discover(workflows, act_names=activities, deps_name=dependencies)
worker_id = create_worker_id(worker_id_prefix or "worker")
event_loop = asyncio.get_event_loop()
deps_cm = (
with_dependencies(
deps,
worker_config=bootstrap_config,
worker_id=worker_id,
event_loop=event_loop,
)
act_names, acts = zip(*discover_activities(activities), strict=False)
if act_names:
if registered:
registered += "\n"
i = len(act_names)
registered += f"- {i} activit{'ies' if i > 1 else 'y'}: {','.join(act_names)}"
if not acts and not wfs:
raise ValueError("Couldn't find any registered activity or workflow.")
logger.info("Starting datashare worker running:\n%s", registered)
client = await TemporalClient.connect(temporal_address, namespace=namespace)
worker = datashare_worker(
client,
workflows=wfs,
activities=acts,
task_queue=queue,
max_concurrent_activities=max_concurrent_activities,
if deps
else _do_nothing_cm
)
try:
await worker.run()
except Exception as e: # noqa: BLE001
await worker.shutdown()
raise e
async with deps_cm:
client = await TemporalClient.connect(temporal_address, namespace=namespace)
acts = [init_activity(a, client=client, event_loop=event_loop) for a in acts]
worker = datashare_worker(
client,
worker_id,
workflows=wfs,
activities=acts,
task_queue=queue,
max_concurrent_activities=max_concurrent_activities,
)
try:
await worker.run()
except Exception as e: # noqa: BLE001
await worker.shutdown()
raise e


@asynccontextmanager
async def _do_nothing_cm() -> AsyncGenerator[None, None]:
    """No-op async context manager used when the worker has no lifetime dependencies.

    NOTE(review): this is a context-manager *factory* — it must be called
    (``_do_nothing_cm()``) before being used in ``async with``. The
    ``else _do_nothing_cm`` branch in ``start`` above passes the bare function
    object, which would fail at runtime with "does not support the asynchronous
    context manager protocol" — confirm and add the missing call parentheses.
    """
    yield
8 changes: 8 additions & 0 deletions datashare-python/datashare_python/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator, Generator, Iterator, Sequence
from pathlib import Path

import aiohttp
import pytest
Expand Down Expand Up @@ -102,6 +103,13 @@ def test_worker_config() -> WorkerConfig:
)


@pytest.fixture
def test_worker_config_path(test_worker_config: WorkerConfig, tmpdir: Path) -> Path:
    """Serialize the test worker config to a temp file and return its path.

    The file is written as JSON; since JSON is valid YAML, the CLI's
    YAML-based config loading can read it.
    """
    serialized = test_worker_config.model_dump_json()
    config_file = Path(tmpdir) / "config.json"
    config_file.write_text(serialized)
    return config_file


@pytest.fixture(scope="session")
async def worker_lifetime_deps(
event_loop: AbstractEventLoop,
Expand Down
94 changes: 92 additions & 2 deletions datashare-python/datashare_python/discovery.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,69 @@
import logging
import re
from collections.abc import Callable, Iterable
from importlib.metadata import entry_points

from .types_ import ContextManagerFactory
from .utils import ActivityWithProgress

logger = logging.getLogger(__name__)

Activity = ActivityWithProgress | Callable | type

_DEPENDENCIES = "dependencies"
_WORKFLOW_GROUPS = "datashare.workflows"
_ACTIVITIES_GROUPS = "datashare.activities"
_DEPENDENCIES_GROUPS = "datashare.dependencies"

_RegisteredWorkflow = tuple[str, type]
_RegisteredActivity = tuple[str, Activity]
_Dependencies = list[ContextManagerFactory]
_Discovery = tuple[
Iterable[_RegisteredWorkflow] | None,
Iterable[_RegisteredActivity] | None,
_Dependencies | None,
]


def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]:
def discover(
    wf_names: list[str] | None, *, act_names: list[str] | None, deps_name: str | None
) -> _Discovery:
    """Discover registered workflows, activities and worker dependencies.

    :param wf_names: workflow name patterns to look up, ``None`` to skip workflows
    :param act_names: activity name patterns to look up, ``None`` to skip activities
    :param deps_name: dependency registry entry name, ``None`` to auto-select when
        the registry has a single entry
    :return: a ``(workflows, activities, dependencies)`` tuple; an item is ``None``
        when the corresponding lookup was skipped
    :raise ValueError: when neither a workflow nor an activity could be found
    """
    summary: list[str] = []
    wfs = None
    if wf_names is not None:
        # Materialize before unpacking: ``zip(*<empty>)`` would raise a confusing
        # "not enough values to unpack" instead of the explicit error below.
        found_wfs = list(discover_workflows(wf_names))
        wfs = ()
        if found_wfs:
            names, wfs = zip(*found_wfs, strict=True)
            n_wfs = len(names)
            summary.append(
                f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {', '.join(names)}"
            )
    acts = None
    if act_names is not None:
        found_acts = list(discover_activities(act_names))
        acts = ()
        if found_acts:
            names, acts = zip(*found_acts, strict=True)
            n_acts = len(names)
            summary.append(
                f"- {n_acts} activit{'ies' if n_acts > 1 else 'y'}:"
                f" {', '.join(names)}"
            )
    if not acts and not wfs:
        raise ValueError("Couldn't find any registered activity or workflow.")
    deps = discover_dependencies(deps_name)
    if deps:
        n_deps = len(deps)
        deps_names = (d.__name__ for d in deps)
        summary.append(
            f"- {n_deps} dependenc{'ies' if n_deps > 1 else 'y'}:"
            f" {', '.join(deps_names)}"
        )
    # Joining parts avoids the stray leading newline the old string-concatenation
    # produced when only dependencies were discovered.
    logger.info("discovered:\n%s", "\n".join(summary))
    return wfs, acts, deps


def discover_workflows(names: list[str]) -> Iterable[_RegisteredWorkflow]:
pattern = None if not names else re.compile(rf"^{'|'.join(names)}$")
impls = entry_points(group=_WORKFLOW_GROUPS)
for wf_impls in impls:
Expand All @@ -24,7 +77,7 @@ def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]:
yield wf_name, wf_impl


def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]:
def discover_activities(names: list[str]) -> Iterable[_RegisteredActivity]:
pattern = None if not names else re.compile(rf"^{'|'.join(names)}$")
impls = entry_points(group=_ACTIVITIES_GROUPS)
for act_impls in impls:
Expand All @@ -38,6 +91,43 @@ def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]:
yield act_name, act_impl


def discover_dependencies(name: str | None) -> _Dependencies | None:
    """Load the dependency registry entry point and resolve one of its entries.

    :param name: registry entry to select; when ``None``, the registry must
        contain exactly one entry, which is returned
    :return: the selected dependencies, or ``None`` when no registry entry point
        exists and no name was requested
    :raise LookupError: when a requested name cannot be found
    :raise ValueError: when the registry is ambiguous or empty
    """
    matches = entry_points(name=_DEPENDENCIES, group=_DEPENDENCIES_GROUPS)
    if not matches:
        # No registry entry point at all: only an error if a name was requested.
        if name is None:
            return None
        available_impls = entry_points(group=_DEPENDENCIES_GROUPS)
        msg = (
            f'failed to find dependency: "{name}", '
            f"available dependencies: {available_impls}"
        )
        raise LookupError(msg)
    if len(matches) > 1:
        msg = f'found multiple dependencies for name "{name}": {matches}'
        raise ValueError(msg)
    registry = matches[_DEPENDENCIES].load()
    if name:
        try:
            return registry[name]
        except KeyError as e:
            available = list(registry)
            msg = (
                f'failed to find dependency for name "{name}", available dependencies: '
                f"{available}"
            )
            raise LookupError(msg) from e
    if not registry:
        raise ValueError("empty dependency registry !")
    if len(registry) > 1:
        available = ", ".join('"' + d + '"' for d in registry)
        msg = (
            f"dependency registry contains multiples entries {available},"
            f" please select one by providing a name"
        )
        raise ValueError(msg)
    # Exactly one entry left at this point.
    (only_entry,) = registry.values()
    return only_entry


def _parse_wf_name(wf_type: type) -> str:
if not isinstance(wf_type, type):
msg = (
Expand Down
45 changes: 45 additions & 0 deletions datashare-python/datashare_python/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import inspect
import logging
import os
import socket
import threading
from asyncio import AbstractEventLoop
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor

from temporalio.worker import PollerBehaviorSimpleMaximum, Worker
Expand Down Expand Up @@ -30,6 +36,7 @@

def datashare_worker(
client: TemporalClient,
worker_id: str,
*,
workflows: list[type] | None = None,
activities: list[Activity] | None = None,
Expand Down Expand Up @@ -60,6 +67,7 @@ def datashare_worker(

return Worker(
client,
identity=worker_id,
workflows=workflows,
activities=activities,
task_queue=task_queue,
Expand All @@ -72,3 +80,40 @@ def datashare_worker(
# several of them
workflow_task_poller_behavior=PollerBehaviorSimpleMaximum(5),
)


def create_worker_id(prefix: str) -> str:
    """Build a worker identity of the form ``<prefix>-<hostname>-<pid>-<thread id>``."""
    # TODO: this might not be unique when using asyncio
    parts = (
        prefix,
        socket.gethostname(),
        str(os.getpid()),
        str(threading.get_ident()),
    )
    return "-".join(parts)


_CLIENT = "client"
_WORKER = "worker"
_EXPECTED_INIT_ARGS = {"self", _CLIENT, _WORKER}


def _get_class_from_method(method: Callable) -> type:
    """Return the class a bound method belongs to.

    For bound methods the class is read directly from the bound target
    (``__self__``), which also works for nested classes and classes not visible
    in the defining module's globals. The previous qualname lookup in
    ``__globals__`` silently returned ``None`` in those cases. The qualname
    lookup is kept as a fallback for plain functions.
    """
    bound_to = getattr(method, "__self__", None)
    if bound_to is not None:
        # ``__self__`` is the class itself for classmethods, an instance otherwise.
        return bound_to if isinstance(bound_to, type) else type(bound_to)
    qualname = method.__qualname__
    class_name = qualname.rsplit(".", 1)[0]
    return method.__globals__.get(class_name)


def init_activity(
    activity: Callable, client: TemporalClient, event_loop: AbstractEventLoop
) -> Callable:
    """Instantiate the class of a method-based activity, injecting its dependencies.

    Plain-function activities are returned unchanged. For a bound method, the
    owning class is resolved and re-instantiated with the constructor arguments
    it accepts (restricted to ``_EXPECTED_INIT_ARGS``).

    NOTE(review): ``_EXPECTED_INIT_ARGS`` is ``{"self", "client", "worker"}``, so
    the ``event_loop`` kwarg built below is always filtered out and never passed,
    and a class whose ``__init__`` takes ``event_loop`` is rejected as invalid —
    confirm whether ``event_loop`` should be in ``_EXPECTED_INIT_ARGS``. Also,
    ``"worker"`` is expected but never supplied here, and since ``"client"`` is
    always kept, the ``if not kwargs`` branch appears unreachable — a class whose
    ``__init__`` takes no ``client`` would fail with an unexpected-kwarg TypeError.
    """
    if not inspect.ismethod(activity):
        # Free functions need no instantiation.
        return activity
    cls = _get_class_from_method(activity)
    init_args = inspect.signature(cls.__init__).parameters
    # Reject constructors taking parameters we don't know how to provide.
    invalid = [p for p in init_args if p not in _EXPECTED_INIT_ARGS]
    if invalid:
        msg = f"invalid activity arguments: {invalid}"
        raise ValueError(msg)
    kwargs = {"client": client, "event_loop": event_loop}
    # Keep only the kwargs the contract allows (see NOTE above re: event_loop).
    kwargs = {k: v for k, v in kwargs.items() if k in _EXPECTED_INIT_ARGS}
    if not kwargs:
        return activity
    return cls(**kwargs)
1 change: 1 addition & 0 deletions datashare-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"datashare-worker-template~=0.1",
"tomlkit~=0.14.0",
"hatchling~=1.27.0",
"pyyaml~=6.0",
]

[project.urls]
Expand Down
12 changes: 9 additions & 3 deletions datashare-python/tests/cli/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from _pytest.capture import CaptureFixture
from _pytest.monkeypatch import MonkeyPatch
from datashare_python.cli import cli_app
Expand All @@ -10,12 +12,13 @@ async def _mock_worker_run(self) -> None: # noqa: ANN001


async def test_start_workers(
worker_lifetime_deps, # noqa: ANN001, ARG001
typer_asyncio_patch, # noqa: ANN001, ARG001
test_worker_config_path: Path,
monkeypatch: MonkeyPatch,
capsys: CaptureFixture[str],
) -> None:
# Given
config_path = test_worker_config_path
runner = CliRunner(mix_stderr=False)
monkeypatch.setattr(Worker, "run", _mock_worker_run)
with capsys.disabled():
Expand All @@ -27,6 +30,8 @@ async def test_start_workers(
"start",
"--queue",
"cpu",
"-c",
str(config_path),
"--activities",
"ping",
"--activities",
Expand All @@ -40,7 +45,8 @@ async def test_start_workers(
)
# Then
assert result.exit_code == 0
expected = """Starting datashare worker running:
expected = """discovered:
- 1 workflow: ping
- 1 activity: create-translation-batches"""
- 1 activity: create-translation-batches
- 4 dependencies: set_loggers, set_event_loop, set_es_client, set_temporal_client"""
assert expected in result.stderr
Loading
Loading