diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 7bafa831f..d0f2bcf1c 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -287,6 +287,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) @click.argument("parameters", type=ParametersType(), default={}, required=False) +@click.option("--ws", type=bool, is_flag=True, default=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -314,6 +315,7 @@ def run_plan( name: str, timeout: float | None, foreground: bool, + ws: bool, instrument_session: str, parameters: TaskParameters, ) -> None: @@ -335,7 +337,13 @@ def on_event(event: AnyEvent) -> None: elif isinstance(event, DataEvent): callback(event.name, event.doc) - resp = client.run_task(task, on_event=on_event) + client.add_callback(on_event) + + if ws: + resp = client.run_blocking(task) + else: + resp = client.run_task(task) + match resp.result: case TaskResult(result=None, type="NoneType"): print("Plan succeeded") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index c5b41ff45..fd2c93f53 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -439,6 +439,26 @@ def get_active_task(self) -> WorkerTask: return self.active_task + @start_as_current_span(TRACER, "request") + def run_blocking( + self, request: TaskRequest, on_event: OnAnyEvent | None = None + ) -> TaskStatus: + for event in self._rest.run_blocking(request): + if on_event is not None: + on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e) + if isinstance(event, WorkerEvent) and event.is_complete(): + if event.task_status is None: + raise BlueskyRemoteControlError( + "Server completed without task status" + ) + return event.task_status + raise BlueskyRemoteControlError("Connection closed before plan completed.") + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 52150d36f..c2d161118 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -8,8 +8,10 @@ get_tracer, start_as_current_span, ) -from pydantic import BaseModel, TypeAdapter, ValidationError +from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl +from websockets.sync.client import connect +from blueapi.client.event_bus import AnyEvent from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( @@ -274,6 +276,20 @@ def _request_and_deserialize( deserialized = TypeAdapter(target_type).validate_python(response.json()) return deserialized + def run_blocking(self, req: TaskRequest): + url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan" + with connect(url) as ws: + ws.send(req.model_dump_json()) + for message in ws: + event = TypeAdapter(AnyEvent).validate_json(message) + yield event + + def _ws_address(self) -> WebsocketUrl: + # url = WebsocketUrl.build( + # scheme="ws", host=api.host, port=api.port, path=api.path + # ) + return WebsocketUrl("ws://localhost:8000/") + # https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0 def __getattr__(name: str): diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 6acc29ab7..5c402492e 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,5 +1,7 @@ +import logging from collections.abc import Mapping from functools import cache +from multiprocessing.connection import Connection from typing import Any from bluesky.callbacks.tiled_writer import TiledWriter @@ -9,6 +11,7 @@ from blueapi.cli.scratch import get_python_environment from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig +from blueapi.core.bluesky_types import DataEvent from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.log import set_up_logging @@ -21,14 +24,14 @@ WorkerTask, ) from blueapi.utils.serialization import access_blob -from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask """This module provides interface between web application and underlying Bluesky context and worker""" - +LOGGER = logging.getLogger(__name__) _CONFIG: ApplicationConfig = ApplicationConfig() @@ -270,3 +273,29 @@ def get_python_env( """Retrieve information about the Python environment""" scratch = config().scratch return get_python_environment(config=scratch, name=name, source=source) + + +SubHandle = tuple[int, int, int] + + +def pipe_events(tx: Connection) -> SubHandle: + + def handler( + worker_event: WorkerEvent | DataEvent | ProgressEvent, + _cor_id: str | None, + ) -> None: + tx.send(worker_event) + + task_worker = worker() + w_id = task_worker.worker_events.subscribe(handler) + d_id = task_worker.data_events.subscribe(handler) + p_id = task_worker.progress_events.subscribe(handler) + return (w_id, d_id, p_id) + + +def unpipe_events(hnd: SubHandle) -> None: + task_worker = worker() + w, d, p = hnd + task_worker.worker_events.unsubscribe(w) + task_worker.data_events.unsubscribe(d) + task_worker.progress_events.unsubscribe(p) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index c79dd3df3..72715ea29 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,6 +2,7 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from multiprocessing import Pipe from typing import Annotated, Any import jwt @@ -14,8 +15,10 @@ HTTPException, Request, Response, + WebSocket, status, ) +from fastapi.concurrency import run_in_threadpool from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse @@ -35,9 +38,11 @@ from super_state_machine.errors import TransitionError from blueapi.config import ApplicationConfig, OIDCConfig, Tag +from blueapi.core.bluesky_types import DataEvent from blueapi.service import interface from blueapi.worker import TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent +from blueapi.worker.worker_errors import WorkerBusyError from .model import ( DeviceModel, @@ -62,6 +67,9 @@ LOGGER = logging.getLogger(__name__) +AnyEvent = WorkerEvent | DataEvent | ProgressEvent + + def _runner() -> WorkerDispatcher: """Intended to be used only with FastAPI Depends""" if RUNNER is None: @@ -540,6 +548,53 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) +@secure_router.websocket("/run_plan") +async def run_plan( + ws: WebSocket, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + user = "alice" + + # ack ws + await ws.accept() + # accept task request through socket + rq = await ws.receive_json() + # submit task to runner + try: + task_request: TaskRequest = TaskRequest.model_validate(rq) + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + except ValidationError: + await ws.close(code=1003, reason="invalid args") + return + except KeyError: + await ws.close(code=1003, reason="unknown plan") + return + + # add listener to runner + tx, rx = Pipe() + h = runner.run(interface.pipe_events, tx=tx) + # start task + try: + task = WorkerTask(task_id=task_id) + runner.run( + interface.begin_task, + task=task, + ) + except WorkerBusyError: + await ws.close(code=1013, reason="Worker busy") + return + # pipe events to ws + try: + while True: + event: AnyEvent = await run_in_threadpool(rx.recv) + await ws.send_json(event.model_dump(mode="json")) + if isinstance(event, WorkerEvent) and event.is_complete(): + break + finally: + await ws.close() + runner.run(interface.unpipe_events, hnd=h) + + @start_as_current_span(TRACER, "config") def start(config: ApplicationConfig): import uvicorn