Skip to content
10 changes: 9 additions & 1 deletion src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def on_event(
@controller.command(name="run")
@click.argument("name", type=str)
@click.argument("parameters", type=ParametersType(), default={}, required=False)
@click.option("--ws", type=bool, is_flag=True, default=False)
@click.option(
"--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True
)
Expand Down Expand Up @@ -314,6 +315,7 @@ def run_plan(
name: str,
timeout: float | None,
foreground: bool,
ws: bool,
instrument_session: str,
parameters: TaskParameters,
) -> None:
Expand All @@ -335,7 +337,13 @@ def on_event(event: AnyEvent) -> None:
elif isinstance(event, DataEvent):
callback(event.name, event.doc)

resp = client.run_task(task, on_event=on_event)
client.add_callback(on_event)

if ws:
resp = client.run_blocking(task)
else:
resp = client.run_task(task)

match resp.result:
case TaskResult(result=None, type="NoneType"):
print("Plan succeeded")
Expand Down
20 changes: 20 additions & 0 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,26 @@ def get_active_task(self) -> WorkerTask:

return self.active_task

@start_as_current_span(TRACER, "request")
def run_blocking(
self, request: TaskRequest, on_event: OnAnyEvent | None = None
) -> TaskStatus:
for event in self._rest.run_blocking(request):
if on_event is not None:
on_event(event)
for cb in self._callbacks.values():
try:
cb(event)
except Exception as e:
log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e)
if isinstance(event, WorkerEvent) and event.is_complete():
if event.task_status is None:
raise BlueskyRemoteControlError(
"Server completed without task status"
)
return event.task_status
raise BlueskyRemoteControlError("Connection closed before plan completed.")

@start_as_current_span(TRACER, "task", "timeout")
def run_task(
self,
Expand Down
18 changes: 17 additions & 1 deletion src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
get_tracer,
start_as_current_span,
)
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl
from websockets.sync.client import connect

from blueapi.client.event_bus import AnyEvent
from blueapi.config import RestConfig
from blueapi.service.authentication import JWTAuth, SessionManager
from blueapi.service.model import (
Expand Down Expand Up @@ -274,6 +276,20 @@ def _request_and_deserialize(
deserialized = TypeAdapter(target_type).validate_python(response.json())
return deserialized

def run_blocking(self, req: TaskRequest):
url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan"
with connect(url) as ws:
ws.send(req.model_dump_json())
for message in ws:
event = TypeAdapter(AnyEvent).validate_json(message)
yield event

def _ws_address(self) -> WebsocketUrl:
# url = WebsocketUrl.build(
# scheme="ws", host=api.host, port=api.port, path=api.path
# )
return WebsocketUrl("ws://localhost:8000/")


# https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0
def __getattr__(name: str):
Expand Down
33 changes: 31 additions & 2 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from collections.abc import Mapping
from functools import cache
from multiprocessing.connection import Connection
from typing import Any

from bluesky.callbacks.tiled_writer import TiledWriter
Expand All @@ -9,6 +11,7 @@

from blueapi.cli.scratch import get_python_environment
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig
from blueapi.core.bluesky_types import DataEvent
from blueapi.core.context import BlueskyContext
from blueapi.core.event import EventStream
from blueapi.log import set_up_logging
Expand All @@ -21,14 +24,14 @@
WorkerTask,
)
from blueapi.utils.serialization import access_blob
from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TaskWorker, TrackableTask

"""This module provides interface between web application and underlying Bluesky
context and worker"""


LOGGER = logging.getLogger(__name__)
_CONFIG: ApplicationConfig = ApplicationConfig()


Expand Down Expand Up @@ -270,3 +273,29 @@ def get_python_env(
"""Retrieve information about the Python environment"""
scratch = config().scratch
return get_python_environment(config=scratch, name=name, source=source)


SubHandle = tuple[int, int, int]


def pipe_events(tx: Connection) -> SubHandle:

def handler(
worker_event: WorkerEvent | DataEvent | ProgressEvent,
_cor_id: str | None,
) -> None:
tx.send(worker_event)

task_worker = worker()
w_id = task_worker.worker_events.subscribe(handler)
d_id = task_worker.data_events.subscribe(handler)
p_id = task_worker.progress_events.subscribe(handler)
return (w_id, d_id, p_id)


def unpipe_events(hnd: SubHandle) -> None:
task_worker = worker()
w, d, p = hnd
task_worker.worker_events.unsubscribe(w)
task_worker.data_events.unsubscribe(d)
task_worker.progress_events.unsubscribe(p)
57 changes: 56 additions & 1 deletion src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from multiprocessing import Pipe
from typing import Annotated, Any

import jwt
Expand All @@ -14,8 +15,10 @@
HTTPException,
Request,
Response,
WebSocket,
status,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.datastructures import Address
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
Expand All @@ -35,9 +38,11 @@
from super_state_machine.errors import TransitionError

from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.core.bluesky_types import DataEvent
from blueapi.service import interface
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum
from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent
from blueapi.worker.worker_errors import WorkerBusyError

from .model import (
DeviceModel,
Expand All @@ -62,6 +67,9 @@
LOGGER = logging.getLogger(__name__)


AnyEvent = WorkerEvent | DataEvent | ProgressEvent


def _runner() -> WorkerDispatcher:
"""Intended to be used only with FastAPI Depends"""
if RUNNER is None:
Expand Down Expand Up @@ -540,6 +548,53 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response:
)


@secure_router.websocket("/run_plan")
async def run_plan(
ws: WebSocket,
runner: Annotated[WorkerDispatcher, Depends(_runner)],
):
user = "alice"

# ack ws
await ws.accept()
# accept task request through socket
rq = await ws.receive_json()
# submit task to runner
try:
task_request: TaskRequest = TaskRequest.model_validate(rq)
task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
except ValidationError:
await ws.close(code=1003, reason="invalid args")
return
except KeyError:
await ws.close(code=1003, reason="unknown plan")
return

# add listener to runner
tx, rx = Pipe()
h = runner.run(interface.pipe_events, tx=tx)
# start task
try:
task = WorkerTask(task_id=task_id)
runner.run(
interface.begin_task,
task=task,
)
except WorkerBusyError:
await ws.close(code=1013, reason="Worker busy")
return
# pipe events to ws
try:
while True:
event: AnyEvent = await run_in_threadpool(rx.recv)
await ws.send_json(event.model_dump(mode="json"))
if isinstance(event, WorkerEvent) and event.is_complete():
break
finally:
await ws.close()
runner.run(interface.unpipe_events, hnd=h)


@start_as_current_span(TRACER, "config")
def start(config: ApplicationConfig):
import uvicorn
Expand Down