diff --git a/python/fusion_engine_client/applications/p1_capture.py b/python/fusion_engine_client/applications/p1_capture.py index 8e545a6e..032216ef 100755 --- a/python/fusion_engine_client/applications/p1_capture.py +++ b/python/fusion_engine_client/applications/p1_capture.py @@ -187,8 +187,7 @@ def main(): _logger.error(f'--log-timestamp-source={options.log_timestamp_source} is not supported. Only "user-sw" timestamps are supported on non-socket captures.') sys.exit(1) - if isinstance(transport, serial.Serial): - transport.timeout = read_timeout_sec + set_read_timeout(transport, read_timeout_sec) # Listen for incoming data. decoder = FusionEngineDecoder(warn_on_unrecognized=not options.quiet and not options.summary, return_bytes=True) @@ -227,7 +226,7 @@ def _print_status(now): received_data = [] # If this is a serial port or file, we set the read timeout above. else: - received_data = transport.read(1024) + received_data = recv_from_transport(transport, 1024) bytes_received += len(received_data) diff --git a/python/fusion_engine_client/applications/p1_filter.py b/python/fusion_engine_client/applications/p1_filter.py index 5e9b2bb3..03416908 100755 --- a/python/fusion_engine_client/applications/p1_filter.py +++ b/python/fusion_engine_client/applications/p1_filter.py @@ -164,10 +164,7 @@ def main(): while True: # Need to specify read size or read waits for end of file character. # This returns immediately even if 0 bytes are available. - if isinstance(input_transport, socket.socket): - received_data = input_transport.recv(64) - else: - received_data = input_transport.read(64) + received_data = recv_from_transport(input_transport, 64) if len(received_data) == 0: time.sleep(0.1) diff --git a/python/fusion_engine_client/utils/socket_timestamping.py b/python/fusion_engine_client/utils/socket_timestamping.py index 7e6a2923..c768e8aa 100755 --- a/python/fusion_engine_client/utils/socket_timestamping.py +++ b/python/fusion_engine_client/utils/socket_timestamping.py @@ -11,7 +11,7 @@ import socket import struct import sys -from typing import BinaryIO, Optional, Tuple, TypeAlias +from typing import BinaryIO, Optional, Tuple, TypeAlias, Union _CMSG: TypeAlias = tuple[int, int, bytes] @@ -127,7 +127,7 @@ def enable_socket_timestamping(sock: socket.socket, enable_sw_timestamp: bool, e return False -def recv(sock: socket.socket, buffer_size: int) -> Tuple[bytes, Optional[float], Optional[float]]: +def recv(sock: Union[socket.socket, BinaryIO], buffer_size: int) -> Tuple[bytes, Optional[float], Optional[float]]: '''! Receive data from the specified socket and capture timestamps, if enabled. @@ -139,7 +139,12 @@ def recv(sock: socket.socket, buffer_size: int) -> Tuple[bytes, Optional[float], - The kernel timestamp, if enabled - The hardware timestamp, if enabled ''' - if sys.platform == "linux": + # Handle non-sockets (websocket, BinaryIO (file), etc.) gracefully. + if not isinstance(sock, socket.socket): + received_data = sock.read(buffer_size) + kernel_ts = None + hw_ts = None + elif sys.platform == "linux": received_data, ancdata, _, _ = sock.recvmsg(buffer_size, 1024) kernel_ts, _, hw_ts = parse_timestamps_from_ancdata(ancdata) else: diff --git a/python/fusion_engine_client/utils/transport_utils.py b/python/fusion_engine_client/utils/transport_utils.py index 677059f7..2723bde4 100644 --- a/python/fusion_engine_client/utils/transport_utils.py +++ b/python/fusion_engine_client/utils/transport_utils.py @@ -1,7 +1,7 @@ import re import socket import sys -from typing import BinaryIO, Callable, TextIO, Union +from typing import Any, BinaryIO, Callable, TextIO, Union # WebSocket support is optional. To use, install with: # pip install websockets @@ -142,6 +142,61 @@ def write(self, data: Union[bytes, bytearray]) -> int: raise RuntimeError('Output file not opened.') +class WebsocketTransport: + """! + @brief Websocket wrapper class, mimicking the Python socket API. + + This class defers all function calls and attribute to the underlying `ws.ClientConnection` websocket instance. Any + function defined for `ClientConnection` should work on this class (e.g., `close()`). + """ + + def __init__(self, *args, **kwargs): + # Note: Omitting "_sec" from argument name for consistent with connect() arguments. + self._read_timeout_sec = kwargs.pop('read_timeout', None) + + self._websocket = kwargs.pop('websocket', None) + if self._websocket is None: + self._websocket = ws.connect(*args, **kwargs) + + def set_timeout(self, timeout_sec: float): + if timeout_sec < 0.0: + self._read_timeout_sec = None + else: + self._read_timeout_sec = timeout_sec + + def recv(self, unused_size_bytes: int = None) -> bytes: + """! + @brief Receive data from the websocket. + + @note + This function wraps the `ws.ClientConnection.recv()` function. WebSockets are not streaming transports, they are + message-oriented. The Python websocket library does not support reading a specified number of bytes. The + `unused_size_bytes` parameter is listed here for consistency with `socket.recv()`. + + @param unused_size_bytes Unused. + + @return The received bytes, or NOne on timeout. + """ + try: + return self._websocket.recv(self._read_timeout_sec) + except TimeoutError as e: + # recv() raises a TimeoutError. We'll raise a socket.timeout exception instead for consistency with socket. + raise socket.timeout(str(e)) + + def __getattr__(self, item: str) -> Any: + # Defer all queries for attributes and functions that are not members of this class to self._websocket. + # __getattribute__() will handle requests for members of this class (recv(), _read_timeout_sec, etc.), and + # __getattr() will not be called. + return getattr(self._websocket, item) + + def __setattr__(self, item: str, value: Any) -> None: + # There is no __setattribute__() like there is for get. See details in __getattr__(). + if item in ('_read_timeout_sec', '_websocket'): + object.__setattr__(self, item, value) + else: + setattr(self._websocket, item, value) + + TRANSPORT_HELP_OPTIONS = """\ - - Read from stdin and/or write to stdout - [file://](PATH|-) - Read from/write to the specified file, or to stdin/stdout @@ -166,10 +221,11 @@ def write(self, data: Union[bytes, bytearray]) -> int: {TRANSPORT_HELP_OPTIONS} """ +TransportType = Union[socket.socket, serial.Serial, WebsocketTransport, FileTransport] + def create_transport(descriptor: str, timeout_sec: float = None, print_func: Callable = None, mode: str = 'both', - stdout=sys.stdout) -> \ - Union[socket.socket, serial.Serial, ws.ClientConnection, FileTransport]: + stdout=sys.stdout) -> TransportType: # File: path, '-' (stdin/stdout), empty string (stdin/stdout) if descriptor in ('', '-'): descriptor = 'file://-' @@ -251,7 +307,7 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal print_func(f'Connecting to {url}.') try: - transport = ws.connect(url, open_timeout=timeout_sec) + transport = WebsocketTransport(url, open_timeout=timeout_sec) except TimeoutError: raise TimeoutError(f'Timed out connecting to {url}.') return transport @@ -300,3 +356,39 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal return transport raise ValueError(f"Unsupported transport descriptor '{descriptor}'.") + + +def recv_from_transport(transport: TransportType, size_bytes: int) -> bytes: + '''! + @brief Helper function for reading from any type of transport. + + This function abstracts `recv()` vs `read()` calls regardless of transport type. + + @param transport The transport to read from. + @param size_bytes The maximum number of bytes to read. + + @return A `bytes` array. + ''' + try: + if isinstance(transport, (socket.socket, WebsocketTransport)): + return transport.recv(size_bytes) + else: + return transport.read(size_bytes) + except (socket.timeout, TimeoutError): + return bytes() + + +def set_read_timeout(transport: TransportType, timeout_sec: float): + if isinstance(transport, socket.socket): + if timeout_sec == 0: + transport.setblocking(False) + else: + transport.setblocking(True) + transport.settimeout(timeout_sec) + elif isinstance(transport, WebsocketTransport): + transport.set_timeout(timeout_sec) + elif isinstance(transport, serial.Serial): + transport.timeout = timeout_sec + else: + # Read timeout not applicable for files. + pass