From 294ec015aa0085c6606b145fcb48a59b874fb322 Mon Sep 17 00:00:00 2001 From: Hector Santos Date: Sat, 28 Feb 2026 12:42:31 +0100 Subject: [PATCH 1/5] feat: implement WebSocket streaming protocol for large payloads #2256 --- rust/Cargo.toml | 2 +- rust/src/client_api.rs | 3 + rust/src/client_api/browser.rs | 105 +++++++-- rust/src/client_api/regular.rs | 155 ++++++++++++- rust/src/client_api/ws_streaming.rs | 326 ++++++++++++++++++++++++++++ 5 files changed, 560 insertions(+), 31 deletions(-) create mode 100644 rust/src/client_api/ws_streaming.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 29d2179..20d4e70 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "freenet-stdlib" -version = "0.1.40" +version = "0.2.0" edition = "2021" rust-version = "1.80" publish = true diff --git a/rust/src/client_api.rs b/rust/src/client_api.rs index be8f1eb..d912f44 100644 --- a/rust/src/client_api.rs +++ b/rust/src/client_api.rs @@ -22,6 +22,9 @@ mod browser; #[cfg(all(target_family = "wasm", feature = "net"))] pub use browser::*; +#[cfg(feature = "net")] +pub(crate) mod ws_streaming; + pub use client_events::*; #[cfg(feature = "net")] diff --git a/rust/src/client_api/browser.rs b/rust/src/client_api/browser.rs index d2382a1..56124bf 100644 --- a/rust/src/client_api/browser.rs +++ b/rust/src/client_api/browser.rs @@ -32,6 +32,10 @@ impl WebApi { { let eh = Rc::new(RefCell::new(error_handler.clone())); let result_handler = Rc::new(RefCell::new(result_handler)); + let reassembly = Rc::new(RefCell::new( + super::ws_streaming::ChunkReassemblyBuffer::new(), + )); + let onmessage_callback = Closure::::new(move |e: MessageEvent| { // Extract the Blob from the MessageEvent let value: JsValue = e.data(); @@ -44,6 +48,7 @@ impl WebApi { let fr_clone = file_reader.clone(); let eh_clone = eh.clone(); let result_handler_clone = result_handler.clone(); + let reassembly_clone = reassembly.clone(); let onloadend_callback = Closure::::new(move || { let 
array_buffer = fr_clone @@ -52,7 +57,39 @@ impl WebApi { .dyn_into::() .unwrap(); let bytes = js_sys::Uint8Array::new(&array_buffer).to_vec(); - let response: HostResult = match bincode::deserialize(&bytes) { + + use super::ws_streaming::{self, StreamMessage}; + let payload = match ws_streaming::parse_message(&bytes) { + Ok(StreamMessage::Complete(data)) => data.to_vec(), + Ok(StreamMessage::Chunk { + total_chunks, + payload: data, + }) => { + match reassembly_clone + .borrow_mut() + .receive_chunk(total_chunks, data) + { + Ok(Some(complete)) => complete, + Ok(None) => return, // more chunks needed + Err(e) => { + eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ + "error": format!("{e}"), + "source": "streaming reassembly" + }))); + return; + } + } + } + Err(e) => { + eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ + "error": format!("{e}"), + "source": "streaming parse" + }))); + return; + } + }; + + let response: HostResult = match bincode::deserialize(&payload) { Ok(val) => val, Err(err) => { eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ @@ -94,7 +131,6 @@ impl WebApi { handler(); } }) as Box); - // conn.add_event_listener_with_callback("open", onopen_callback.as_ref().unchecked_ref()); conn.set_onopen(Some(onopen_callback.as_ref().unchecked_ref())); onopen_callback.forget(); @@ -116,8 +152,10 @@ impl WebApi { } pub async fn send(&mut self, request: ClientRequest<'static>) -> Result<(), Error> { - // Check WebSocket ready state before sending - // Per WebSocket spec, send() silently discards data when socket is CLOSING/CLOSED + use super::ws_streaming::{self, CHUNK_THRESHOLD}; + + // Check WebSocket ready state before sending. + // Per WebSocket spec, send() silently discards data when socket is CLOSING/CLOSED. 
let ready_state = self.conn.ready_state(); if ready_state != web_sys::WebSocket::OPEN { let state_name = match ready_state { @@ -137,29 +175,48 @@ impl WebApi { } let send = bincode::serialize(&request)?; - self.conn.send_with_u8_array(&send).map_err(|err| { - let err: serde_json::Value = match serde_wasm_bindgen::from_value(err) { - Ok(e) => e, - Err(e) => { - let e = serde_json::json!({ - "error": format!("{e}"), - "origin": "request serialization", - "request": format!("{request:?}"), - }); - (self.error_handler)(Error::ConnectionError(e.clone())); - return Error::ConnectionError(e); - } - }; - (self.error_handler)(Error::ConnectionError(serde_json::json!({ - "error": err, - "origin": "request sending", - "request": format!("{request:?}"), - }))); - Error::ConnectionError(err) - })?; + + if send.len() > CHUNK_THRESHOLD { + let chunks = ws_streaming::chunk_payload(&send); + for chunk in &chunks { + self.conn + .send_with_u8_array(chunk) + .map_err(|err| Self::map_send_error(err, &request, &mut self.error_handler))?; + } + } else { + let wrapped = ws_streaming::wrap_complete(send); + self.conn + .send_with_u8_array(&wrapped) + .map_err(|err| Self::map_send_error(err, &request, &mut self.error_handler))?; + } Ok(()) } + fn map_send_error( + err: JsValue, + request: &ClientRequest<'_>, + error_handler: &mut Box, + ) -> Error { + let err: serde_json::Value = match serde_wasm_bindgen::from_value(err) { + Ok(e) => e, + Err(e) => { + let e = serde_json::json!({ + "error": format!("{e}"), + "origin": "request serialization", + "request": format!("{request:?}"), + }); + error_handler(Error::ConnectionError(e.clone())); + return Error::ConnectionError(e); + } + }; + error_handler(Error::ConnectionError(serde_json::json!({ + "error": err, + "origin": "request sending", + "request": format!("{request:?}"), + }))); + Error::ConnectionError(err) + } + pub fn disconnect(self, cause: impl AsRef) { let _ = self.conn.close_with_code_and_reason(1000, cause.as_ref()); } diff --git 
a/rust/src/client_api/regular.rs b/rust/src/client_api/regular.rs index ae00723..63f6738 100644 --- a/rust/src/client_api/regular.rs +++ b/rust/src/client_api/regular.rs @@ -145,6 +145,8 @@ async fn request_handler( mut response_tx: Sender, mut conn: Connection, ) { + let mut reassembly = super::ws_streaming::ChunkReassemblyBuffer::new(); + let error = loop { tokio::select! { req = request_rx.recv() => { @@ -154,7 +156,7 @@ async fn request_handler( } } res = conn.next() => { - match process_response(&mut conn, &mut response_tx, res).await { + match process_response(&mut conn, &mut response_tx, res, &mut reassembly).await { Ok(_) => continue, Err(err) => break err, } @@ -175,11 +177,23 @@ async fn process_request( conn: &mut Connection, req: Option>, ) -> Result<(), Error> { + use super::ws_streaming::{self, CHUNK_THRESHOLD}; + let req = req.ok_or(Error::ChannelClosed)?; let msg = bincode::serialize(&req) .map_err(Into::into) .map_err(Error::OtherError)?; - conn.send(Message::Binary(msg.into())).await?; + + if msg.len() > CHUNK_THRESHOLD { + let chunks = ws_streaming::chunk_payload(&msg); + for chunk in chunks { + conn.send(Message::Binary(chunk.into())).await?; + } + } else { + let wrapped = ws_streaming::wrap_complete(msg); + conn.send(Message::Binary(wrapped.into())).await?; + } + if let ClientRequest::Disconnect { cause } = req { conn.close(cause.map(|c| CloseFrame { code: CloseCode::Normal, @@ -199,18 +213,51 @@ async fn process_response( conn: &mut Connection, response_tx: &mut Sender, res: Option>, + reassembly: &mut super::ws_streaming::ChunkReassemblyBuffer, ) -> Result<(), Error> { + use super::ws_streaming::{self, StreamMessage}; + let res = res.ok_or(Error::ConnectionClosed)??; match res { Message::Text(msg) => { - let response: HostResult = bincode::deserialize(msg.as_bytes())?; + let bytes = match ws_streaming::parse_message(msg.as_bytes()) + .map_err(|e| Error::OtherError(e.into()))? 
+ { + StreamMessage::Complete(payload) => payload.to_vec(), + StreamMessage::Chunk { + total_chunks, + payload, + } => match reassembly + .receive_chunk(total_chunks, payload) + .map_err(|e| Error::OtherError(e.into()))? + { + Some(complete) => complete, + None => return Ok(()), + }, + }; + let response: HostResult = bincode::deserialize(&bytes)?; response_tx .send(response) .await .map_err(|_| Error::ChannelClosed)?; } Message::Binary(binary) => { - let response: HostResult = bincode::deserialize(&binary)?; + let bytes = match ws_streaming::parse_message(&binary) + .map_err(|e| Error::OtherError(e.into()))? + { + StreamMessage::Complete(payload) => payload.to_vec(), + StreamMessage::Chunk { + total_chunks, + payload, + } => match reassembly + .receive_chunk(total_chunks, payload) + .map_err(|e| Error::OtherError(e.into()))? + { + Some(complete) => complete, + None => return Ok(()), + }, + }; + let response: HostResult = bincode::deserialize(&bytes)?; response_tx .send(response) .await @@ -253,6 +300,8 @@ mod test { self, tx: tokio::sync::oneshot::Sender<()>, ) -> Result<(), Box> { + use crate::client_api::ws_streaming; + let (stream, _) = tokio::time::timeout(Duration::from_millis(10), self.listener.accept()).await??; let mut stream = tokio_tungstenite::accept_async(stream).await?; @@ -260,7 +309,8 @@ mod test { if !self.recv { let res: HostResult = Ok(HostResponse::Ok); let req = bincode::serialize(&res)?; - stream.send(Message::Binary(req.into())).await?; + let wrapped = ws_streaming::wrap_complete(req); + stream.send(Message::Binary(wrapped.into())).await?; } let Message::Binary(msg) = stream.next().await.ok_or_else(|| "no msg".to_owned())?? @@ -268,12 +318,105 @@ mod test { return Err("wrong msg".to_owned().into()); }; - let _req: ClientRequest = bincode::deserialize(&msg)?; + // Unwrap the streaming envelope + let payload = match ws_streaming::parse_message(&msg)? 
{ + ws_streaming::StreamMessage::Complete(data) => data.to_vec(), + ws_streaming::StreamMessage::Chunk { .. } => { + return Err("unexpected chunk in test".to_owned().into()); + } + }; + + let _req: ClientRequest = bincode::deserialize(&payload)?; tx.send(()).map_err(|_| "couldn't error".to_owned())?; Ok(()) } } + struct ChunkedServer { + listener: TcpListener, + payload_size: usize, + } + + impl ChunkedServer { + async fn new(port: u16, payload_size: usize) -> Self { + let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) + .await + .unwrap(); + ChunkedServer { + listener, + payload_size, + } + } + + async fn listen( + self, + tx: tokio::sync::oneshot::Sender<()>, + ) -> Result<(), Box> { + use crate::client_api::ws_streaming; + use crate::contract_interface::{ContractCode, ContractKey, WrappedState}; + use crate::parameters::Parameters; + + let (stream, _) = + tokio::time::timeout(Duration::from_millis(100), self.listener.accept()).await??; + let mut stream = tokio_tungstenite::accept_async(stream).await?; + + let state = WrappedState::new(vec![0xAB; self.payload_size]); + let code = ContractCode::from(vec![1, 2, 3]); + let key = ContractKey::from_params_and_code(Parameters::from(vec![]), &code); + let res: HostResult = Ok(HostResponse::ContractResponse( + crate::client_api::ContractResponse::GetResponse { + key, + contract: None, + state, + }, + )); + let serialized = bincode::serialize(&res)?; + + // Send as chunks + let chunks = ws_streaming::chunk_payload(&serialized); + assert!(chunks.len() > 1, "payload should produce multiple chunks"); + for chunk in chunks { + stream.send(Message::Binary(chunk.into())).await?; + } + + // Wait for client disconnect + let msg = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; + drop(msg); + tx.send(()).map_err(|_| "signal failed".to_owned())?; + Ok(()) + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_recv_chunked() -> Result<(), Box> { + use 
crate::client_api::ContractResponse; + + let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let payload_size = 600 * 1024; // 600 KiB state → multiple chunks + let server = ChunkedServer::new(port, payload_size).await; + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + let server_result = tokio::task::spawn(server.listen(tx)); + let (ws_conn, _) = + tokio_tungstenite::connect_async(format!("ws://localhost:{port}/")).await?; + let mut client = WebApi::start(ws_conn); + + let response = client.recv().await?; + match response { + HostResponse::ContractResponse(ContractResponse::GetResponse { state, .. }) => { + assert_eq!(state.size(), payload_size); + assert!(state.as_ref().iter().all(|&b| b == 0xAB)); + } + other => panic!("expected GetResponse, got {other:?}"), + } + + client + .send(ClientRequest::Disconnect { cause: None }) + .await?; + tokio::time::timeout(Duration::from_millis(100), rx).await??; + tokio::time::timeout(Duration::from_millis(100), server_result).await???; + Ok(()) + } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_send() -> Result<(), Box> { let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); diff --git a/rust/src/client_api/ws_streaming.rs b/rust/src/client_api/ws_streaming.rs new file mode 100644 index 0000000..d4097d7 --- /dev/null +++ b/rust/src/client_api/ws_streaming.rs @@ -0,0 +1,326 @@ +//! WebSocket message streaming protocol for large payloads (client-side). +//! +//! Protocol constants are duplicated from the server-side module to avoid +//! adding a crate dependency from stdlib to core. +//! +//! All WebSocket messages use a 1-byte type prefix: +//! - `0x00` + payload = complete message +//! - `0x01` + 4 bytes (total_chunks LE) + payload = stream chunk +//! +//! The chunk header is 5 bytes (`CHUNK_HEADER_SIZE`): the 1-byte type prefix +//! followed by a single little-endian `u32` total_chunks field. 
+ +const MSG_COMPLETE: u8 = 0x00; +const MSG_CHUNK: u8 = 0x01; + +/// 1 (type) + 4 (total_chunks). +pub const CHUNK_HEADER_SIZE: usize = 5; + +/// Default chunk payload size: 256 KiB. +pub const DEFAULT_CHUNK_SIZE: usize = 256 * 1024; + +/// Messages larger than this threshold are chunked. +pub const CHUNK_THRESHOLD: usize = 512 * 1024; + +/// Maximum `total_chunks` accepted from the wire. +/// Based on MAX_STATE_SIZE (50 MiB) / DEFAULT_CHUNK_SIZE. +const MAX_TOTAL_CHUNKS: u32 = 256; + +/// Parsed streaming message. +#[derive(Debug)] +pub enum StreamMessage<'a> { + Complete(&'a [u8]), + Chunk { + total_chunks: u32, + payload: &'a [u8], + }, +} + +#[derive(Debug, thiserror::Error)] +pub enum StreamError { + #[error("message too short: expected at least {expected} bytes, got {actual}")] + MessageTooShort { expected: usize, actual: usize }, + #[error("unknown message type prefix: 0x{0:02x}")] + UnknownMessageType(u8), + #[error("total_chunks is zero")] + ZeroTotalChunks, + #[error("total_chunks {total_chunks} exceeds maximum {max}")] + TotalChunksTooLarge { total_chunks: u32, max: u32 }, + #[error("total_chunks mismatch (expected {expected}, got {actual})")] + TotalChunksMismatch { expected: u32, actual: u32 }, +} + +/// Wraps a serialized payload as a complete (non-chunked) streaming message. +pub fn wrap_complete(data: Vec) -> Vec { + let mut buf = Vec::with_capacity(1 + data.len()); + buf.push(MSG_COMPLETE); + buf.extend_from_slice(&data); + buf +} + +/// Splits a serialized payload into chunked streaming messages. 
+pub fn chunk_payload(data: &[u8]) -> Vec> { + if data.is_empty() { + let mut buf = Vec::with_capacity(CHUNK_HEADER_SIZE); + buf.push(MSG_CHUNK); + buf.extend_from_slice(&1u32.to_le_bytes()); + return vec![buf]; + } + + let total_chunks = data.len().div_ceil(DEFAULT_CHUNK_SIZE); + let mut chunks = Vec::with_capacity(total_chunks); + + for chunk_data in data.chunks(DEFAULT_CHUNK_SIZE) { + let mut buf = Vec::with_capacity(CHUNK_HEADER_SIZE + chunk_data.len()); + buf.push(MSG_CHUNK); + buf.extend_from_slice(&(total_chunks as u32).to_le_bytes()); + buf.extend_from_slice(chunk_data); + chunks.push(buf); + } + + chunks +} + +/// Parses a raw WebSocket binary message into a streaming protocol message. +pub fn parse_message(data: &[u8]) -> Result, StreamError> { + if data.is_empty() { + return Err(StreamError::MessageTooShort { + expected: 1, + actual: 0, + }); + } + + match data[0] { + MSG_COMPLETE => Ok(StreamMessage::Complete(&data[1..])), + MSG_CHUNK => { + if data.len() < CHUNK_HEADER_SIZE { + return Err(StreamError::MessageTooShort { + expected: CHUNK_HEADER_SIZE, + actual: data.len(), + }); + } + let total_chunks = u32::from_le_bytes([data[1], data[2], data[3], data[4]]); + + if total_chunks == 0 { + return Err(StreamError::ZeroTotalChunks); + } + if total_chunks > MAX_TOTAL_CHUNKS { + return Err(StreamError::TotalChunksTooLarge { + total_chunks, + max: MAX_TOTAL_CHUNKS, + }); + } + + Ok(StreamMessage::Chunk { + total_chunks, + payload: &data[CHUNK_HEADER_SIZE..], + }) + } + other => Err(StreamError::UnknownMessageType(other)), + } +} + +/// Sequential reassembly buffer for chunked streams. +/// +/// TCP guarantees ordered delivery and the select loop serializes message sends, +/// so chunks always arrive in order. This buffer simply appends incoming chunks +/// and returns the complete payload when all arrive. 
+pub struct ChunkReassemblyBuffer { + data: Vec, + total_chunks: u32, + received: u32, +} + +impl ChunkReassemblyBuffer { + pub fn new() -> Self { + Self { + data: Vec::new(), + total_chunks: 0, + received: 0, + } + } + + /// Receives a chunk and returns the fully reassembled payload when all chunks arrive. + /// + /// Returns `Ok(None)` if more chunks are needed. + pub fn receive_chunk( + &mut self, + total_chunks: u32, + payload: &[u8], + ) -> Result>, StreamError> { + if self.received == 0 { + self.total_chunks = total_chunks; + self.data + .reserve(total_chunks as usize * DEFAULT_CHUNK_SIZE); + } else if self.total_chunks != total_chunks { + return Err(StreamError::TotalChunksMismatch { + expected: self.total_chunks, + actual: total_chunks, + }); + } + + self.data.extend_from_slice(payload); + self.received += 1; + + if self.received == self.total_chunks { + let result = std::mem::take(&mut self.data); + self.received = 0; + self.total_chunks = 0; + Ok(Some(result)) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn wrap_complete_roundtrip() { + let data = vec![1, 2, 3, 4, 5]; + let wrapped = wrap_complete(data.clone()); + assert_eq!(wrapped[0], MSG_COMPLETE); + match parse_message(&wrapped).unwrap() { + StreamMessage::Complete(payload) => assert_eq!(payload, &data[..]), + StreamMessage::Chunk { .. 
} => panic!("expected Complete"), + } + } + + #[test] + fn chunk_small_payload_roundtrip() { + let data = vec![42u8; 1024]; + let chunks = chunk_payload(&data); + assert_eq!(chunks.len(), 1); + + match parse_message(&chunks[0]).unwrap() { + StreamMessage::Chunk { + total_chunks, + payload, + } => { + assert_eq!(total_chunks, 1); + assert_eq!(payload, &data[..]); + } + StreamMessage::Complete(_) => panic!("expected Chunk"), + } + } + + #[test] + fn chunk_large_payload_roundtrip() { + let data: Vec = (0..600 * 1024).map(|i| (i % 256) as u8).collect(); + let chunks = chunk_payload(&data); + assert_eq!(chunks.len(), 3); + + let mut reassembly = ChunkReassemblyBuffer::new(); + for (i, chunk) in chunks.iter().enumerate() { + match parse_message(chunk).unwrap() { + StreamMessage::Chunk { + total_chunks, + payload, + } => { + let result = reassembly.receive_chunk(total_chunks, payload).unwrap(); + if i < 2 { + assert!(result.is_none()); + } else { + assert_eq!(result.unwrap(), data); + } + } + StreamMessage::Complete(_) => panic!("expected Chunk"), + } + } + } + + #[test] + fn chunk_empty_payload() { + let chunks = chunk_payload(&[]); + assert_eq!(chunks.len(), 1); + + match parse_message(&chunks[0]).unwrap() { + StreamMessage::Chunk { + total_chunks, + payload, + } => { + assert_eq!(total_chunks, 1); + assert!(payload.is_empty()); + + let mut reassembly = ChunkReassemblyBuffer::new(); + let result = reassembly.receive_chunk(total_chunks, payload).unwrap(); + assert_eq!(result.unwrap(), Vec::::new()); + } + StreamMessage::Complete(_) => panic!("expected Chunk"), + } + } + + #[test] + fn parse_errors() { + assert!(matches!( + parse_message(&[]).unwrap_err(), + StreamError::MessageTooShort { .. } + )); + assert!(matches!( + parse_message(&[0xFF, 1, 2, 3]).unwrap_err(), + StreamError::UnknownMessageType(0xFF) + )); + assert!(matches!( + parse_message(&[MSG_CHUNK, 0, 0]).unwrap_err(), + StreamError::MessageTooShort { .. 
} + )); + + let mut zero_chunks = vec![MSG_CHUNK]; + zero_chunks.extend_from_slice(&0u32.to_le_bytes()); + assert!(matches!( + parse_message(&zero_chunks).unwrap_err(), + StreamError::ZeroTotalChunks + )); + + let mut too_large = vec![MSG_CHUNK]; + too_large.extend_from_slice(&1000u32.to_le_bytes()); + assert!(matches!( + parse_message(&too_large).unwrap_err(), + StreamError::TotalChunksTooLarge { .. } + )); + } + + #[test] + fn total_chunks_mismatch() { + let mut reassembly = ChunkReassemblyBuffer::new(); + reassembly.receive_chunk(3, &[1, 2, 3]).unwrap(); + assert!(matches!( + reassembly.receive_chunk(5, &[4, 5, 6]).unwrap_err(), + StreamError::TotalChunksMismatch { .. } + )); + } + + #[test] + fn reassembly_resets_after_completion() { + let data_a = vec![0xAA; DEFAULT_CHUNK_SIZE * 2]; + let data_b = vec![0xBB; DEFAULT_CHUNK_SIZE * 3]; + + let mut reassembly = ChunkReassemblyBuffer::new(); + + for chunk in &chunk_payload(&data_a) { + if let StreamMessage::Chunk { + total_chunks, + payload, + } = parse_message(chunk).unwrap() + { + if let Some(r) = reassembly.receive_chunk(total_chunks, payload).unwrap() { + assert_eq!(r, data_a); + } + } + } + + for chunk in &chunk_payload(&data_b) { + if let StreamMessage::Chunk { + total_chunks, + payload, + } = parse_message(chunk).unwrap() + { + if let Some(r) = reassembly.receive_chunk(total_chunks, payload).unwrap() { + assert_eq!(r, data_b); + } + } + } + } +} From 9d839f0d6dedafa20ceb88c958317b71de5e28e8 Mon Sep 17 00:00:00 2001 From: Hector Santos Date: Sun, 1 Mar 2026 18:45:02 +0100 Subject: [PATCH 2/5] feat: add streaming support for WebSocket messages with chunking --- rust/src/client_api.rs | 2 +- rust/src/client_api/browser.rs | 87 +-- rust/src/client_api/client_events.rs | 112 +++- rust/src/client_api/regular.rs | 435 ++++++++++---- rust/src/client_api/streaming.rs | 540 ++++++++++++++++++ rust/src/client_api/ws_streaming.rs | 326 ----------- .../src/generated/client_request_generated.rs | 224 +++++++- 
rust/src/generated/host_response_generated.rs | 224 +++++++- schemas/flatbuffers/client_request.fbs | 10 +- schemas/flatbuffers/host_response.fbs | 10 +- 10 files changed, 1478 insertions(+), 492 deletions(-) create mode 100644 rust/src/client_api/streaming.rs delete mode 100644 rust/src/client_api/ws_streaming.rs diff --git a/rust/src/client_api.rs b/rust/src/client_api.rs index d912f44..42ccc29 100644 --- a/rust/src/client_api.rs +++ b/rust/src/client_api.rs @@ -23,7 +23,7 @@ mod browser; pub use browser::*; #[cfg(feature = "net")] -pub(crate) mod ws_streaming; +pub mod streaming; pub use client_events::*; diff --git a/rust/src/client_api/browser.rs b/rust/src/client_api/browser.rs index 56124bf..5afd7ff 100644 --- a/rust/src/client_api/browser.rs +++ b/rust/src/client_api/browser.rs @@ -32,9 +32,7 @@ impl WebApi { { let eh = Rc::new(RefCell::new(error_handler.clone())); let result_handler = Rc::new(RefCell::new(result_handler)); - let reassembly = Rc::new(RefCell::new( - super::ws_streaming::ChunkReassemblyBuffer::new(), - )); + let reassembly = Rc::new(RefCell::new(super::streaming::ReassemblyBuffer::new())); let onmessage_callback = Closure::::new(move |e: MessageEvent| { // Extract the Blob from the MessageEvent @@ -58,18 +56,51 @@ impl WebApi { .unwrap(); let bytes = js_sys::Uint8Array::new(&array_buffer).to_vec(); - use super::ws_streaming::{self, StreamMessage}; - let payload = match ws_streaming::parse_message(&bytes) { - Ok(StreamMessage::Complete(data)) => data.to_vec(), - Ok(StreamMessage::Chunk { - total_chunks, - payload: data, + use super::client_events::HostResponse; + + let response: HostResult = match bincode::deserialize(&bytes) { + Ok(val) => val, + Err(err) => { + eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ + "error": format!("{err}"), + "source": "host response deserialization" + }))); + return; + } + }; + + match response { + Ok(HostResponse::StreamHeader { .. 
}) => { + // StreamHeader is metadata only — the following StreamChunks + // will be reassembled transparently by the ReassemblyBuffer. + // Browser incremental streaming is not yet supported. + return; + } + Ok(HostResponse::StreamChunk { + stream_id, + index, + total, + data, }) => { match reassembly_clone .borrow_mut() - .receive_chunk(total_chunks, data) + .receive_chunk(stream_id, index, total, &data) { - Ok(Some(complete)) => complete, + Ok(Some(complete)) => { + let inner: HostResult = match bincode::deserialize(&complete) { + Ok(val) => val, + Err(err) => { + eh_clone.borrow_mut()(Error::ConnectionError( + serde_json::json!({ + "error": format!("{err}"), + "source": "stream reassembly deserialization" + }), + )); + return; + } + }; + result_handler_clone.borrow_mut()(inner); + } Ok(None) => return, // more chunks needed Err(e) => { eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ @@ -80,26 +111,10 @@ impl WebApi { } } } - Err(e) => { - eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ - "error": format!("{e}"), - "source": "streaming parse" - }))); - return; - } - }; - - let response: HostResult = match bincode::deserialize(&payload) { - Ok(val) => val, - Err(err) => { - eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ - "error": format!("{err}"), - "source": "host response deserialization" - }))); - return; + other => { + result_handler_clone.borrow_mut()(other); } - }; - result_handler_clone.borrow_mut()(response); + } }); // Set the FileReader handlers @@ -152,7 +167,7 @@ impl WebApi { } pub async fn send(&mut self, request: ClientRequest<'static>) -> Result<(), Error> { - use super::ws_streaming::{self, CHUNK_THRESHOLD}; + use super::streaming::{chunk_request, CHUNK_THRESHOLD}; // Check WebSocket ready state before sending. // Per WebSocket spec, send() silently discards data when socket is CLOSING/CLOSED. 
@@ -177,16 +192,18 @@ impl WebApi { let send = bincode::serialize(&request)?; if send.len() > CHUNK_THRESHOLD { - let chunks = ws_streaming::chunk_payload(&send); + let stream_id = 0; // browser client uses single stream + let chunks = chunk_request(send, stream_id); for chunk in &chunks { + let chunk_bytes = + bincode::serialize(chunk).map_err(|e| Error::OtherError(e.into()))?; self.conn - .send_with_u8_array(chunk) + .send_with_u8_array(&chunk_bytes) .map_err(|err| Self::map_send_error(err, &request, &mut self.error_handler))?; } } else { - let wrapped = ws_streaming::wrap_complete(send); self.conn - .send_with_u8_array(&wrapped) + .send_with_u8_array(&send) .map_err(|err| Self::map_send_error(err, &request, &mut self.error_handler))?; } Ok(()) diff --git a/rust/src/client_api/client_events.rs b/rust/src/client_api/client_events.rs index 920c249..30481fb 100644 --- a/rust/src/client_api/client_events.rs +++ b/rust/src/client_api/client_events.rs @@ -32,6 +32,7 @@ use crate::generated::host_response::{ NotFoundArgs, Ok as FbsOk, OkArgs, OutboundDelegateMsg as FbsOutboundDelegateMsg, OutboundDelegateMsgArgs, OutboundDelegateMsgType, PutResponse as FbsPutResponse, PutResponseArgs, RequestUserInput as FbsRequestUserInput, RequestUserInputArgs, + StreamChunk as FbsHostStreamChunk, StreamChunkArgs as FbsHostStreamChunkArgs, UpdateNotification as FbsUpdateNotification, UpdateNotificationArgs, UpdateResponse as FbsUpdateResponse, UpdateResponseArgs, }; @@ -257,6 +258,13 @@ pub enum ClientRequest<'a> { NodeQueries(NodeQuery), /// Gracefully disconnect from the host. Close, + /// A chunk of a larger streamed message. 
+ StreamChunk { + stream_id: u32, + index: u32, + total: u32, + data: Vec, + }, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -319,6 +327,17 @@ impl ClientRequest<'_> { ClientRequest::Authenticate { token } => ClientRequest::Authenticate { token }, ClientRequest::NodeQueries(query) => ClientRequest::NodeQueries(query), ClientRequest::Close => ClientRequest::Close, + ClientRequest::StreamChunk { + stream_id, + index, + total, + data, + } => ClientRequest::StreamChunk { + stream_id, + index, + total, + data, + }, } } @@ -355,7 +374,20 @@ impl ClientRequest<'_> { token: token.to_owned(), } } - _ => unreachable!(), + ClientRequestType::StreamChunk => { + let chunk = client_request.client_request_as_stream_chunk().unwrap(); + ClientRequest::StreamChunk { + stream_id: chunk.stream_id(), + index: chunk.index(), + total: chunk.total(), + data: chunk.data().bytes().to_vec(), + } + } + _ => { + return Err(WsApiError::deserialization( + "unknown client request type".to_string(), + )) + } }, Err(e) => { let cause = format!("{e}"); @@ -641,6 +673,12 @@ impl Display for ClientRequest<'_> { ClientRequest::Authenticate { .. } => write!(f, "authenticate"), ClientRequest::NodeQueries(query) => write!(f, "node queries: {:?}", query), ClientRequest::Close => write!(f, "close"), + ClientRequest::StreamChunk { + stream_id, + index, + total, + .. + } => write!(f, "stream chunk {index}/{total} (stream {stream_id})"), } } } @@ -704,6 +742,33 @@ pub enum HostResponse { QueryResponse(QueryResponse), /// A requested action which doesn't require an answer was performed successfully. Ok, + /// A chunk of a larger streamed response. + StreamChunk { + stream_id: u32, + index: u32, + total: u32, + data: Vec, + }, + /// Header message announcing the start of a streamed response. + /// Sent before the corresponding [`StreamChunk`] messages so the client + /// can set up incremental consumption via [`WsStreamHandle`]. 
+ StreamHeader { + stream_id: u32, + total_bytes: u64, + content: StreamContent, + }, +} + +/// Describes what kind of response is being streamed. +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum StreamContent { + /// A streamed GetResponse — the large state is delivered via StreamChunks. + GetResponse { + key: ContractKey, + includes_contract: bool, + }, + /// Raw binary stream (future use). + Raw, } type Peer = String; @@ -1513,6 +1578,40 @@ impl HostResponse { Ok(builder.finished_data().to_vec()) } HostResponse::QueryResponse(_) => unimplemented!(), + HostResponse::StreamChunk { + stream_id, + index, + total, + data, + } => { + let data_offset = builder.create_vector(&data); + let chunk_offset = FbsHostStreamChunk::create( + &mut builder, + &FbsHostStreamChunkArgs { + stream_id, + index, + total, + data: Some(data_offset), + }, + ); + let host_response_offset = FbsHostResponse::create( + &mut builder, + &HostResponseArgs { + response_type: HostResponseType::StreamChunk, + response: Some(chunk_offset.as_union_value()), + }, + ); + finish_host_response_buffer(&mut builder, host_response_offset); + Ok(builder.finished_data().to_vec()) + } + HostResponse::StreamHeader { .. } => { + // StreamHeader is only sent over bincode (Native encoding) to + // streaming-capable clients. Flatbuffers clients use transparent + // reassembly via StreamChunk only. + Err(Box::new(ClientError::from(ErrorKind::Unhandled { + cause: "StreamHeader is not supported over flatbuffers encoding".into(), + }))) + } } } } @@ -1543,6 +1642,17 @@ impl Display for HostResponse { HostResponse::DelegateResponse { .. } => write!(f, "delegate responses"), HostResponse::Ok => write!(f, "ok response"), HostResponse::QueryResponse(_) => write!(f, "query response"), + HostResponse::StreamChunk { + stream_id, + index, + total, + .. + } => write!(f, "stream chunk {index}/{total} (stream {stream_id})"), + HostResponse::StreamHeader { + stream_id, + total_bytes, + .. 
+ } => write!(f, "stream header (stream {stream_id}, {total_bytes} bytes)"), } } } diff --git a/rust/src/client_api/regular.rs b/rust/src/client_api/regular.rs index 63f6738..bbaf10f 100644 --- a/rust/src/client_api/regular.rs +++ b/rust/src/client_api/regular.rs @@ -1,7 +1,8 @@ -use std::{borrow::Cow, task::Poll}; +use std::{borrow::Cow, collections::HashMap, task::Poll}; use super::{ - client_events::{ClientError, ClientRequest, ErrorKind}, + client_events::{ClientError, ClientRequest, ErrorKind, HostResponse}, + streaming::WsStreamHandle, Error, HostResult, }; use futures::{pin_mut, FutureExt, Sink, SinkExt, Stream, StreamExt}; @@ -22,6 +23,7 @@ type Connection = WebSocketStream>; pub struct WebApi { request_tx: Sender>, response_rx: Receiver, + stream_rx: Receiver, queue: Vec>, } @@ -106,10 +108,17 @@ impl WebApi { pub fn start(connection: Connection) -> Self { let (request_tx, request_rx) = mpsc::channel(1); let (response_tx, response_rx) = mpsc::channel(1); - tokio::spawn(request_handler(request_rx, response_tx, connection)); + let (stream_tx, stream_rx) = mpsc::channel(4); + tokio::spawn(request_handler( + request_rx, + response_tx, + stream_tx, + connection, + )); Self { request_tx, response_rx, + stream_rx, queue: vec![], } } @@ -124,9 +133,41 @@ impl WebApi { Ok(()) } + /// Receive the next host response. + /// + /// If the server sends a streamed response (StreamHeader + StreamChunks), + /// this method transparently reassembles the full payload and returns the + /// complete [`HostResponse`] — the caller does not need to handle streaming. + /// + /// For incremental consumption, use [`recv_stream()`](Self::recv_stream) instead. pub async fn recv(&mut self) -> HostResult { - let res = self.response_rx.recv().await; - res.ok_or_else(|| ClientError::from(ErrorKind::ChannelClosed))? + tokio::select! { + res = self.response_rx.recv() => { + res.ok_or_else(|| ClientError::from(ErrorKind::ChannelClosed))? 
+ } + handle = self.stream_rx.recv() => { + let handle = handle.ok_or_else(|| ClientError::from(ErrorKind::ChannelClosed))?; + let complete = handle + .assemble() + .await + .map_err(|e| ClientError::from(format!("{e}")))?; + let inner: HostResult = bincode::deserialize(&complete) + .map_err(|e| ClientError::from(format!("{e}")))?; + inner + } + } + } + + /// Receive the next streamed response as a [`WsStreamHandle`]. + /// + /// Returns a handle for incremental consumption of a streamed response. + /// Use [`WsStreamHandle::into_stream()`] for chunk-by-chunk processing or + /// [`WsStreamHandle::assemble()`] to wait for the complete payload. + /// + /// Only returns when the server sends a `StreamHeader`; non-streamed + /// responses are delivered through [`recv()`](Self::recv). + pub async fn recv_stream(&mut self) -> Result { + self.stream_rx.recv().await.ok_or(Error::ChannelClosed) } #[doc(hidden)] @@ -143,20 +184,30 @@ impl WebApi { async fn request_handler( mut request_rx: Receiver>, mut response_tx: Sender, + stream_tx: Sender, mut conn: Connection, ) { - let mut reassembly = super::ws_streaming::ChunkReassemblyBuffer::new(); + let mut reassembly = super::streaming::ReassemblyBuffer::new(); + let mut stream_senders: HashMap = HashMap::new(); + let mut next_stream_id: u32 = 0; let error = loop { tokio::select! 
{ req = request_rx.recv() => { - match process_request(&mut conn, req).await { + match process_request(&mut conn, req, &mut next_stream_id).await { Ok(_) => continue, Err(err) => break err, } } res = conn.next() => { - match process_response(&mut conn, &mut response_tx, res, &mut reassembly).await { + match process_response( + &mut conn, + &mut response_tx, + &stream_tx, + &mut stream_senders, + res, + &mut reassembly, + ).await { Ok(_) => continue, Err(err) => break err, } @@ -172,12 +223,12 @@ async fn request_handler( let _ = response_tx.send(Err(error)).await; } -#[inline] async fn process_request( conn: &mut Connection, req: Option>, + next_stream_id: &mut u32, ) -> Result<(), Error> { - use super::ws_streaming::{self, CHUNK_THRESHOLD}; + use super::streaming::{chunk_request, CHUNK_THRESHOLD}; let req = req.ok_or(Error::ChannelClosed)?; let msg = bincode::serialize(&req) @@ -185,13 +236,17 @@ async fn process_request( .map_err(Error::OtherError)?; if msg.len() > CHUNK_THRESHOLD { - let chunks = ws_streaming::chunk_payload(&msg); + let stream_id = *next_stream_id; + *next_stream_id = next_stream_id.wrapping_add(1); + let chunks = chunk_request(msg, stream_id); for chunk in chunks { - conn.send(Message::Binary(chunk.into())).await?; + let chunk_bytes = bincode::serialize(&chunk) + .map_err(Into::into) + .map_err(Error::OtherError)?; + conn.send(Message::Binary(chunk_bytes.into())).await?; } } else { - let wrapped = ws_streaming::wrap_complete(msg); - conn.send(Message::Binary(wrapped.into())).await?; + conn.send(Message::Binary(msg.into())).await?; } if let ClientRequest::Disconnect { cause } = req { @@ -208,69 +263,114 @@ async fn process_request( Ok(()) } -#[inline] async fn process_response( conn: &mut Connection, response_tx: &mut Sender, + stream_tx: &Sender, + stream_senders: &mut HashMap, res: Option>, - reassembly: &mut super::ws_streaming::ChunkReassemblyBuffer, + reassembly: &mut super::streaming::ReassemblyBuffer, ) -> Result<(), Error> { - use 
super::ws_streaming::{self, StreamMessage}; - let res = res.ok_or(Error::ConnectionClosed)??; match res { - Message::Text(msg) => { - let bytes = match ws_streaming::parse_message(msg.as_bytes()) - .map_err(|e| Error::OtherError(e.into()))? - { - StreamMessage::Complete(payload) => payload.to_vec(), - StreamMessage::Chunk { - total_chunks, - payload, - } => match reassembly - .receive_chunk(total_chunks, payload) - .map_err(|e| Error::OtherError(e.into()))? - { - Some(complete) => complete, - None => return Ok(()), - }, - }; - let response: HostResult = bincode::deserialize(&bytes)?; - response_tx - .send(response) + Message::Binary(binary) => { + handle_response_payload(&binary, response_tx, stream_tx, stream_senders, reassembly) + .await + } + Message::Text(text) => { + handle_response_payload( + text.as_bytes(), + response_tx, + stream_tx, + stream_senders, + reassembly, + ) + .await + } + Message::Ping(ping) => { + conn.send(Message::Pong(ping)).await?; + Ok(()) + } + Message::Pong(_) => Ok(()), + Message::Close(_) => Err(Error::ConnectionClosed), + _ => Ok(()), + } +} + +async fn handle_response_payload( + bytes: &[u8], + response_tx: &mut Sender, + stream_tx: &Sender, + stream_senders: &mut HashMap, + reassembly: &mut super::streaming::ReassemblyBuffer, +) -> Result<(), Error> { + let response: HostResult = bincode::deserialize(bytes)?; + match response { + Ok(HostResponse::StreamHeader { + stream_id, + total_bytes, + content, + }) => { + // Cap open streams to prevent unbounded growth from abandoned streams + if stream_senders.len() >= super::streaming::MAX_CONCURRENT_STREAMS { + tracing::warn!("too many open stream senders, evicting one"); + if let Some(&id) = stream_senders.keys().next() { + stream_senders.remove(&id); + } + } + let (handle, sender) = super::streaming::ws_stream_pair(content, total_bytes); + stream_senders.insert(stream_id, sender); + stream_tx + .send(handle) .await .map_err(|_| Error::ChannelClosed)?; + Ok(()) } - Message::Binary(binary) 
=> { - let bytes = match ws_streaming::parse_message(&binary) - .map_err(|e| Error::OtherError(e.into()))? - { - StreamMessage::Complete(payload) => payload.to_vec(), - StreamMessage::Chunk { - total_chunks, - payload, - } => match reassembly - .receive_chunk(total_chunks, payload) + Ok(HostResponse::StreamChunk { + stream_id, + index, + total, + data, + }) => { + // If we have a sender for this stream_id, it was preceded by a StreamHeader + // → route chunks to the WsStreamSender for app-level streaming. + if let Some(sender) = stream_senders.get(&stream_id) { + if let Err(e) = sender.send_chunk(data) { + tracing::warn!(stream_id, "stream chunk send failed: {e}"); + stream_senders.remove(&stream_id); + return Ok(()); + } + // Drop sender on last chunk so the handle's rx closes + if index + 1 == total { + stream_senders.remove(&stream_id); + } + Ok(()) + } else { + // No StreamHeader seen → transparent reassembly (backward compat) + match reassembly + .receive_chunk(stream_id, index, total, &data) .map_err(|e| Error::OtherError(e.into()))? 
{ - Some(complete) => complete, - None => return Ok(()), - }, - }; - let response: HostResult = bincode::deserialize(&bytes)?; + Some(complete) => { + let inner: HostResult = bincode::deserialize(&complete)?; + response_tx + .send(inner) + .await + .map_err(|_| Error::ChannelClosed)?; + Ok(()) + } + None => Ok(()), + } + } + } + other => { response_tx - .send(response) + .send(other) .await .map_err(|_| Error::ChannelClosed)?; + Ok(()) } - Message::Ping(ping) => { - conn.send(Message::Pong(ping)).await?; - } - Message::Pong(_) => {} - Message::Close(_) => return Err(Error::ConnectionClosed), - _ => {} } - Ok(()) } #[cfg(test)] @@ -300,17 +400,14 @@ mod test { self, tx: tokio::sync::oneshot::Sender<()>, ) -> Result<(), Box> { - use crate::client_api::ws_streaming; - let (stream, _) = tokio::time::timeout(Duration::from_millis(10), self.listener.accept()).await??; let mut stream = tokio_tungstenite::accept_async(stream).await?; if !self.recv { let res: HostResult = Ok(HostResponse::Ok); - let req = bincode::serialize(&res)?; - let wrapped = ws_streaming::wrap_complete(req); - stream.send(Message::Binary(wrapped.into())).await?; + let bytes = bincode::serialize(&res)?; + stream.send(Message::Binary(bytes.into())).await?; } let Message::Binary(msg) = stream.next().await.ok_or_else(|| "no msg".to_owned())?? @@ -318,93 +415,189 @@ mod test { return Err("wrong msg".to_owned().into()); }; - // Unwrap the streaming envelope - let payload = match ws_streaming::parse_message(&msg)? { - ws_streaming::StreamMessage::Complete(data) => data.to_vec(), - ws_streaming::StreamMessage::Chunk { .. } => { - return Err("unexpected chunk in test".to_owned().into()); - } - }; - - let _req: ClientRequest = bincode::deserialize(&payload)?; + let _req: ClientRequest = bincode::deserialize(&msg)?; tx.send(()).map_err(|_| "couldn't error".to_owned())?; Ok(()) } } - struct ChunkedServer { + /// Build a serialized GetResponse payload of the given size and fill byte. 
+ fn build_test_payload( + payload_size: usize, + fill: u8, + ) -> (Vec, crate::contract_interface::ContractKey) { + use crate::contract_interface::{ContractCode, ContractKey, WrappedState}; + use crate::parameters::Parameters; + + let state = WrappedState::new(vec![fill; payload_size]); + let code = ContractCode::from(vec![1, 2, 3]); + let key = ContractKey::from_params_and_code(Parameters::from(vec![]), &code); + let res: HostResult = Ok(HostResponse::ContractResponse( + crate::client_api::ContractResponse::GetResponse { + key, + contract: None, + state, + }, + )); + (bincode::serialize(&res).unwrap(), key) + } + + /// Accept a WS connection and send chunks (optionally preceded by a StreamHeader). + async fn serve_chunked_response( listener: TcpListener, payload_size: usize, + fill: u8, + send_header: bool, + tx: tokio::sync::oneshot::Sender<()>, + ) -> Result<(), Box> { + use crate::client_api::streaming; + + let (tcp_stream, _) = + tokio::time::timeout(Duration::from_millis(100), listener.accept()).await??; + let mut stream = tokio_tungstenite::accept_async(tcp_stream).await?; + + let (serialized, key) = build_test_payload(payload_size, fill); + let stream_id = 0u32; + + if send_header { + use crate::client_api::client_events::StreamContent; + let header: HostResult = Ok(HostResponse::StreamHeader { + stream_id, + total_bytes: serialized.len() as u64, + content: StreamContent::GetResponse { + key, + includes_contract: false, + }, + }); + let header_bytes = bincode::serialize(&header)?; + stream.send(Message::Binary(header_bytes.into())).await?; + } + + let chunks = streaming::chunk_response(serialized, stream_id); + assert!(chunks.len() > 1, "payload should produce multiple chunks"); + for chunk in chunks { + let chunk_result: HostResult = Ok(chunk); + let chunk_bytes = bincode::serialize(&chunk_result)?; + stream.send(Message::Binary(chunk_bytes.into())).await?; + } + + let msg = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; + 
drop(msg); + tx.send(()).map_err(|_| "signal failed".to_owned())?; + Ok(()) } - impl ChunkedServer { - async fn new(port: u16, payload_size: usize) -> Self { - let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) - .await - .unwrap(); - ChunkedServer { - listener, - payload_size, + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_recv_chunked() -> Result<(), Box> { + use crate::client_api::ContractResponse; + + let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let payload_size = 600 * 1024; + let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) + .await + .unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + let server_result = tokio::task::spawn(serve_chunked_response( + listener, + payload_size, + 0xAB, + false, + tx, + )); + let (ws_conn, _) = + tokio_tungstenite::connect_async(format!("ws://localhost:{port}/")).await?; + let mut client = WebApi::start(ws_conn); + + let response = client.recv().await?; + match response { + HostResponse::ContractResponse(ContractResponse::GetResponse { state, .. 
}) => { + assert_eq!(state.size(), payload_size); + assert!(state.as_ref().iter().all(|&b| b == 0xAB)); } + other => panic!("expected GetResponse, got {other:?}"), } - async fn listen( - self, - tx: tokio::sync::oneshot::Sender<()>, - ) -> Result<(), Box> { - use crate::client_api::ws_streaming; - use crate::contract_interface::{ContractCode, ContractKey, WrappedState}; - use crate::parameters::Parameters; + client + .send(ClientRequest::Disconnect { cause: None }) + .await?; + tokio::time::timeout(Duration::from_millis(100), rx).await??; + tokio::time::timeout(Duration::from_millis(100), server_result).await???; + Ok(()) + } - let (stream, _) = - tokio::time::timeout(Duration::from_millis(100), self.listener.accept()).await??; - let mut stream = tokio_tungstenite::accept_async(stream).await?; + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_recv_stream_header() -> Result<(), Box> { + use crate::client_api::ContractResponse; - let state = WrappedState::new(vec![0xAB; self.payload_size]); - let code = ContractCode::from(vec![1, 2, 3]); - let key = ContractKey::from_params_and_code(Parameters::from(vec![]), &code); - let res: HostResult = Ok(HostResponse::ContractResponse( - crate::client_api::ContractResponse::GetResponse { - key, - contract: None, - state, - }, - )); - let serialized = bincode::serialize(&res)?; - - // Send as chunks - let chunks = ws_streaming::chunk_payload(&serialized); - assert!(chunks.len() > 1, "payload should produce multiple chunks"); - for chunk in chunks { - stream.send(Message::Binary(chunk.into())).await?; - } + let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let payload_size = 600 * 1024; + let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) + .await + .unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + let server_result = tokio::task::spawn(serve_chunked_response( + listener, + payload_size, + 0xCD, + true, + tx, + )); + let (ws_conn, _) = + 
tokio_tungstenite::connect_async(format!("ws://localhost:{port}/")).await?; + let mut client = WebApi::start(ws_conn); - // Wait for client disconnect - let msg = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; - drop(msg); - tx.send(()).map_err(|_| "signal failed".to_owned())?; - Ok(()) + // Use recv_stream() to get the handle + let handle = client.recv_stream().await.unwrap(); + assert!(handle.total_bytes() >= payload_size as u64); + + // Assemble and verify + let complete = handle.assemble().await.unwrap(); + let inner: HostResult = bincode::deserialize(&complete)?; + match inner? { + HostResponse::ContractResponse(ContractResponse::GetResponse { state, .. }) => { + assert_eq!(state.size(), payload_size); + assert!(state.as_ref().iter().all(|&b| b == 0xCD)); + } + other => panic!("expected GetResponse, got {other:?}"), } + + client + .send(ClientRequest::Disconnect { cause: None }) + .await?; + tokio::time::timeout(Duration::from_millis(100), rx).await??; + tokio::time::timeout(Duration::from_millis(100), server_result).await???; + Ok(()) } + /// Tests that recv() transparently assembles StreamHeader+StreamChunk flows. 
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_recv_chunked() -> Result<(), Box> { + async fn test_recv_transparent_stream_header( + ) -> Result<(), Box> { use crate::client_api::ContractResponse; let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - let payload_size = 600 * 1024; // 600 KiB state → multiple chunks - let server = ChunkedServer::new(port, payload_size).await; + let payload_size = 600 * 1024; + let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) + .await + .unwrap(); let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let server_result = tokio::task::spawn(server.listen(tx)); + let server_result = tokio::task::spawn(serve_chunked_response( + listener, + payload_size, + 0xCD, + true, + tx, + )); let (ws_conn, _) = tokio_tungstenite::connect_async(format!("ws://localhost:{port}/")).await?; let mut client = WebApi::start(ws_conn); + // Use recv() which should auto-assemble the stream let response = client.recv().await?; match response { HostResponse::ContractResponse(ContractResponse::GetResponse { state, .. }) => { assert_eq!(state.size(), payload_size); - assert!(state.as_ref().iter().all(|&b| b == 0xAB)); + assert!(state.as_ref().iter().all(|&b| b == 0xCD)); } other => panic!("expected GetResponse, got {other:?}"), } diff --git a/rust/src/client_api/streaming.rs b/rust/src/client_api/streaming.rs new file mode 100644 index 0000000..013c875 --- /dev/null +++ b/rust/src/client_api/streaming.rs @@ -0,0 +1,540 @@ +//! Chunking and reassembly helpers for WebSocket message streaming. +//! +//! Large serialized payloads are split into [`ClientRequest::StreamChunk`] or +//! [`HostResponse::StreamChunk`] variants. Each chunk carries a `stream_id` so +//! multiple streams can be reassembled concurrently. + +use std::collections::HashMap; + +use super::{ClientRequest, HostResponse}; + +/// Default chunk payload size: 256 KiB. 
+pub const CHUNK_SIZE: usize = 256 * 1024; + +/// Messages larger than this threshold are chunked. +pub const CHUNK_THRESHOLD: usize = 512 * 1024; + +/// Maximum `total_chunks` accepted from the wire. +/// 256 chunks * 256 KiB = 64 MiB, enough headroom for MAX_STATE_SIZE (50 MiB) +/// plus serialization overhead. +pub const MAX_TOTAL_CHUNKS: u32 = 256; + +/// Maximum concurrent streams in a single `ReassemblyBuffer`. +pub const MAX_CONCURRENT_STREAMS: usize = 8; + +/// Chunks to send before yielding to the tokio runtime. +pub const MAX_CHUNKS_PER_BATCH: usize = 32; + +/// Compute chunking metadata: returns (total_chunks, chunk_iterator). +fn chunk_data(data: &[u8]) -> impl Iterator { + let total = data.len().div_ceil(CHUNK_SIZE).max(1) as u32; + data.chunks(CHUNK_SIZE) + .chain(if data.is_empty() { + // Yield a single empty slice for empty payloads + Some([].as_slice()).into_iter() + } else { + None.into_iter() + }) + .enumerate() + .map(move |(i, chunk)| (i as u32, total, chunk)) +} + +/// Split a serialized request payload into `StreamChunk` client request variants. +pub fn chunk_request(data: Vec, stream_id: u32) -> Vec> { + chunk_data(&data) + .map(|(index, total, chunk)| ClientRequest::StreamChunk { + stream_id, + index, + total, + data: chunk.to_vec(), + }) + .collect() +} + +/// Split a serialized response payload into `StreamChunk` host response variants. 
+pub fn chunk_response(data: Vec, stream_id: u32) -> Vec { + chunk_data(&data) + .map(|(index, total, chunk)| HostResponse::StreamChunk { + stream_id, + index, + total, + data: chunk.to_vec(), + }) + .collect() +} + +#[derive(Debug, thiserror::Error)] +pub enum StreamError { + #[error("total_chunks is zero")] + ZeroTotalChunks, + #[error("total_chunks {total} exceeds maximum {max}")] + TotalChunksTooLarge { total: u32, max: u32 }, + #[error("total_chunks mismatch for stream {stream_id} (expected {expected}, got {actual})")] + TotalChunksMismatch { + stream_id: u32, + expected: u32, + actual: u32, + }, + #[error("duplicate chunk index {index} for stream {stream_id}")] + DuplicateChunk { stream_id: u32, index: u32 }, + #[error("chunk index {index} out of range for stream {stream_id} (total {total})")] + IndexOutOfRange { + stream_id: u32, + index: u32, + total: u32, + }, + #[error("too many concurrent streams ({count}), maximum is {max}")] + TooManyConcurrentStreams { count: usize, max: usize }, + #[error("stream channel closed")] + ChannelClosed, + #[error("stream truncated: received {received} of {expected} bytes")] + Truncated { received: u64, expected: u64 }, +} + +struct StreamState { + chunks: Vec>>, + total: u32, + received: u32, +} + +/// Reassembly buffer keyed by stream ID. Supports concurrent streams. +pub struct ReassemblyBuffer { + streams: HashMap, +} + +impl ReassemblyBuffer { + pub fn new() -> Self { + Self { + streams: HashMap::new(), + } + } + + /// Feed a chunk into the buffer. Returns the fully reassembled payload + /// when all chunks for a stream have arrived. 
+ pub fn receive_chunk( + &mut self, + stream_id: u32, + index: u32, + total: u32, + data: &[u8], + ) -> Result>, StreamError> { + if total == 0 { + return Err(StreamError::ZeroTotalChunks); + } + if total > MAX_TOTAL_CHUNKS { + return Err(StreamError::TotalChunksTooLarge { + total, + max: MAX_TOTAL_CHUNKS, + }); + } + if index >= total { + return Err(StreamError::IndexOutOfRange { + stream_id, + index, + total, + }); + } + + // Reject new streams when the concurrent stream limit is reached. + if !self.streams.contains_key(&stream_id) && self.streams.len() >= MAX_CONCURRENT_STREAMS { + return Err(StreamError::TooManyConcurrentStreams { + count: self.streams.len(), + max: MAX_CONCURRENT_STREAMS, + }); + } + + let state = self + .streams + .entry(stream_id) + .or_insert_with(|| StreamState { + chunks: vec![None; total as usize], + total, + received: 0, + }); + + if state.total != total { + return Err(StreamError::TotalChunksMismatch { + stream_id, + expected: state.total, + actual: total, + }); + } + + let idx = index as usize; + if state.chunks[idx].is_some() { + return Err(StreamError::DuplicateChunk { stream_id, index }); + } + + state.chunks[idx] = Some(data.to_vec()); + state.received += 1; + + if state.received == state.total { + let state = self.streams.remove(&stream_id).unwrap(); + let exact_len: usize = state.chunks.iter().flatten().map(|c| c.len()).sum(); + let mut result = Vec::with_capacity(exact_len); + for chunk in state.chunks.into_iter().flatten() { + result.extend_from_slice(&chunk); + } + Ok(Some(result)) + } else { + Ok(None) + } + } +} + +impl Default for ReassemblyBuffer { + fn default() -> Self { + Self::new() + } +} + +// --- App-level streaming API (requires tokio) --- + +#[cfg(all(feature = "net", not(target_family = "wasm")))] +pub use app_stream::*; + +#[cfg(all(feature = "net", not(target_family = "wasm")))] +mod app_stream { + use super::StreamError; + use crate::client_api::client_events::StreamContent; + use std::pin::Pin; + use 
std::task::{Context, Poll}; + use tokio::sync::mpsc; + + /// Client-side handle for consuming a WebSocket stream incrementally. + /// + /// Created when the client receives a [`HostResponse::StreamHeader`] from the node. + /// The corresponding [`WsStreamSender`] feeds chunks into this handle as they arrive. + /// + /// Two consumption modes: + /// - [`into_stream()`](WsStreamHandle::into_stream) — incremental async `Stream>` + /// - [`assemble()`](WsStreamHandle::assemble) — blocking wait for the complete payload + pub struct WsStreamHandle { + content: StreamContent, + total_bytes: u64, + chunk_rx: mpsc::UnboundedReceiver>, + } + + impl WsStreamHandle { + /// Metadata describing what is being streamed. + pub fn content(&self) -> &StreamContent { + &self.content + } + + /// Total expected bytes across all chunks. + pub fn total_bytes(&self) -> u64 { + self.total_bytes + } + + /// Consume chunks incrementally as an async `Stream`. + pub fn into_stream(self) -> WsStream { + WsStream { + chunk_rx: self.chunk_rx, + } + } + + /// Wait for all chunks and return the fully reassembled payload. + /// + /// Returns [`StreamError::Truncated`] if the sender closes before all + /// expected bytes have been delivered. + pub async fn assemble(mut self) -> Result, StreamError> { + // Cap pre-allocation to avoid OOM from a malicious total_bytes header. + const MAX_PREALLOC: usize = 50 * 1024 * 1024; + let mut buf = Vec::with_capacity((self.total_bytes as usize).min(MAX_PREALLOC)); + while let Some(chunk) = self.chunk_rx.recv().await { + buf.extend_from_slice(&chunk); + } + if (buf.len() as u64) < self.total_bytes { + return Err(StreamError::Truncated { + received: buf.len() as u64, + expected: self.total_bytes, + }); + } + Ok(buf) + } + } + + /// Async stream of chunk data produced by [`WsStreamHandle::into_stream()`]. 
+ pub struct WsStream { + chunk_rx: mpsc::UnboundedReceiver>, + } + + impl futures::Stream for WsStream { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.chunk_rx.poll_recv(cx) + } + } + + /// Sender side — held by the request handler to feed chunks into the handle. + pub struct WsStreamSender { + chunk_tx: mpsc::UnboundedSender>, + } + + impl WsStreamSender { + /// Send a chunk of data to the corresponding [`WsStreamHandle`]. + pub fn send_chunk(&self, data: Vec) -> Result<(), StreamError> { + self.chunk_tx + .send(data) + .map_err(|_| StreamError::ChannelClosed) + } + } + + /// Create a paired (handle, sender) for a new stream. + pub fn ws_stream_pair( + content: StreamContent, + total_bytes: u64, + ) -> (WsStreamHandle, WsStreamSender) { + let (tx, rx) = mpsc::unbounded_channel(); + ( + WsStreamHandle { + content, + total_bytes, + chunk_rx: rx, + }, + WsStreamSender { chunk_tx: tx }, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn chunk_request_small() { + let data = vec![42u8; 1024]; + let chunks = chunk_request(data.clone(), 1); + assert_eq!(chunks.len(), 1); + match &chunks[0] { + ClientRequest::StreamChunk { + stream_id, + index, + total, + data: chunk_data, + } => { + assert_eq!(*stream_id, 1); + assert_eq!(*index, 0); + assert_eq!(*total, 1); + assert_eq!(chunk_data, &data); + } + _ => panic!("expected StreamChunk"), + } + } + + #[test] + fn chunk_request_large_roundtrip() { + let data: Vec = (0..600 * 1024).map(|i| (i % 256) as u8).collect(); + let chunks = chunk_request(data.clone(), 42); + assert_eq!(chunks.len(), 3); + + let mut buf = ReassemblyBuffer::new(); + for chunk in &chunks { + if let ClientRequest::StreamChunk { + stream_id, + index, + total, + data: chunk_data, + } = chunk + { + if let Some(result) = buf + .receive_chunk(*stream_id, *index, *total, chunk_data) + .unwrap() + { + assert_eq!(result, data); + } + } + } + } + + #[test] + fn chunk_response_roundtrip() { 
+ let data = vec![0xAB; CHUNK_SIZE * 2]; + let chunks = chunk_response(data.clone(), 7); + assert_eq!(chunks.len(), 2); + + let mut buf = ReassemblyBuffer::new(); + for chunk in &chunks { + if let HostResponse::StreamChunk { + stream_id, + index, + total, + data: chunk_data, + } = chunk + { + if let Some(result) = buf + .receive_chunk(*stream_id, *index, *total, chunk_data) + .unwrap() + { + assert_eq!(result, data); + } + } + } + } + + #[test] + fn chunk_empty() { + let chunks = chunk_request(Vec::new(), 1); + assert_eq!(chunks.len(), 1); + match &chunks[0] { + ClientRequest::StreamChunk { total, data, .. } => { + assert_eq!(*total, 1); + assert!(data.is_empty()); + } + _ => panic!("expected StreamChunk"), + } + } + + #[test] + fn reassembly_resets_after_completion() { + let data_a = vec![0xAA; CHUNK_SIZE * 2]; + let data_b = vec![0xBB; CHUNK_SIZE * 3]; + + let mut buf = ReassemblyBuffer::new(); + + for chunk in &chunk_request(data_a.clone(), 1) { + if let ClientRequest::StreamChunk { + stream_id, + index, + total, + data, + } = chunk + { + if let Some(r) = buf.receive_chunk(*stream_id, *index, *total, data).unwrap() { + assert_eq!(r, data_a); + } + } + } + + for chunk in &chunk_request(data_b.clone(), 2) { + if let ClientRequest::StreamChunk { + stream_id, + index, + total, + data, + } = chunk + { + if let Some(r) = buf.receive_chunk(*stream_id, *index, *total, data).unwrap() { + assert_eq!(r, data_b); + } + } + } + } + + #[test] + fn zero_total_chunks_error() { + let mut buf = ReassemblyBuffer::new(); + let err = buf.receive_chunk(1, 0, 0, &[1, 2, 3]).unwrap_err(); + assert!(matches!(err, StreamError::ZeroTotalChunks)); + } + + #[test] + fn total_too_large_error() { + let mut buf = ReassemblyBuffer::new(); + let err = buf.receive_chunk(1, 0, 1000, &[1, 2, 3]).unwrap_err(); + assert!(matches!(err, StreamError::TotalChunksTooLarge { .. 
})); + } + + #[test] + fn total_mismatch_error() { + let mut buf = ReassemblyBuffer::new(); + buf.receive_chunk(1, 0, 3, &[1, 2, 3]).unwrap(); + let err = buf.receive_chunk(1, 1, 5, &[4, 5, 6]).unwrap_err(); + assert!(matches!(err, StreamError::TotalChunksMismatch { .. })); + } + + #[test] + fn duplicate_chunk_error() { + let mut buf = ReassemblyBuffer::new(); + buf.receive_chunk(1, 0, 3, &[1, 2, 3]).unwrap(); + let err = buf.receive_chunk(1, 0, 3, &[4, 5, 6]).unwrap_err(); + assert!(matches!( + err, + StreamError::DuplicateChunk { + stream_id: 1, + index: 0 + } + )); + } + + #[test] + fn index_out_of_range_error() { + let mut buf = ReassemblyBuffer::new(); + let err = buf.receive_chunk(1, 5, 3, &[1, 2, 3]).unwrap_err(); + assert!(matches!( + err, + StreamError::IndexOutOfRange { + index: 5, + total: 3, + .. + } + )); + } + + #[test] + fn too_many_concurrent_streams_error() { + let mut buf = ReassemblyBuffer::new(); + for i in 0..MAX_CONCURRENT_STREAMS as u32 { + buf.receive_chunk(i, 0, 2, &[1]).unwrap(); + } + let err = buf + .receive_chunk(MAX_CONCURRENT_STREAMS as u32, 0, 2, &[1]) + .unwrap_err(); + assert!(matches!(err, StreamError::TooManyConcurrentStreams { .. 
})); + } + + #[cfg(all(feature = "net", not(target_family = "wasm")))] + mod stream_handle_tests { + use super::super::*; + use crate::client_api::client_events::StreamContent; + use crate::prelude::{ContractCode, Parameters}; + use futures::StreamExt; + + #[tokio::test] + async fn ws_stream_assemble() { + let code = ContractCode::from(vec![1, 2, 3]); + let key = + crate::prelude::ContractKey::from_params_and_code(Parameters::from(vec![]), &code); + let content = StreamContent::GetResponse { + key, + includes_contract: false, + }; + let data = vec![0xAB; CHUNK_SIZE * 3]; + let (handle, sender) = ws_stream_pair(content, data.len() as u64); + + // Feed chunks in a background task + let data_clone = data.clone(); + tokio::spawn(async move { + for chunk in data_clone.chunks(CHUNK_SIZE) { + sender.send_chunk(chunk.to_vec()).unwrap(); + } + // sender dropped here → handle's rx will close + }); + + let result = handle.assemble().await.unwrap(); + assert_eq!(result, data); + } + + #[tokio::test] + async fn ws_stream_incremental() { + let content = StreamContent::Raw; + let data = vec![0xCD; CHUNK_SIZE * 2]; + let (handle, sender) = ws_stream_pair(content, data.len() as u64); + + tokio::spawn(async move { + for chunk in data.chunks(CHUNK_SIZE) { + sender.send_chunk(chunk.to_vec()).unwrap(); + } + }); + + let mut stream = handle.into_stream(); + let mut collected = Vec::new(); + while let Some(chunk) = stream.next().await { + collected.extend_from_slice(&chunk); + } + assert_eq!(collected.len(), CHUNK_SIZE * 2); + assert!(collected.iter().all(|&b| b == 0xCD)); + } + } +} diff --git a/rust/src/client_api/ws_streaming.rs b/rust/src/client_api/ws_streaming.rs deleted file mode 100644 index d4097d7..0000000 --- a/rust/src/client_api/ws_streaming.rs +++ /dev/null @@ -1,326 +0,0 @@ -//! WebSocket message streaming protocol for large payloads (client-side). -//! -//! Protocol constants are duplicated from the server-side module to avoid -//! 
adding a crate dependency from stdlib to core. -//! -//! All WebSocket messages use a 1-byte type prefix: -//! - `0x00` + payload = complete message -//! - `0x01` + 4 bytes (total_chunks LE) + payload = stream chunk -//! -//! The chunk header is 5 bytes (`CHUNK_HEADER_SIZE`): the 1-byte type prefix -//! followed by a single little-endian `u32` total_chunks field. - -const MSG_COMPLETE: u8 = 0x00; -const MSG_CHUNK: u8 = 0x01; - -/// 1 (type) + 4 (total_chunks). -pub const CHUNK_HEADER_SIZE: usize = 5; - -/// Default chunk payload size: 256 KiB. -pub const DEFAULT_CHUNK_SIZE: usize = 256 * 1024; - -/// Messages larger than this threshold are chunked. -pub const CHUNK_THRESHOLD: usize = 512 * 1024; - -/// Maximum `total_chunks` accepted from the wire. -/// Based on MAX_STATE_SIZE (50 MiB) / DEFAULT_CHUNK_SIZE. -const MAX_TOTAL_CHUNKS: u32 = 256; - -/// Parsed streaming message. -#[derive(Debug)] -pub enum StreamMessage<'a> { - Complete(&'a [u8]), - Chunk { - total_chunks: u32, - payload: &'a [u8], - }, -} - -#[derive(Debug, thiserror::Error)] -pub enum StreamError { - #[error("message too short: expected at least {expected} bytes, got {actual}")] - MessageTooShort { expected: usize, actual: usize }, - #[error("unknown message type prefix: 0x{0:02x}")] - UnknownMessageType(u8), - #[error("total_chunks is zero")] - ZeroTotalChunks, - #[error("total_chunks {total_chunks} exceeds maximum {max}")] - TotalChunksTooLarge { total_chunks: u32, max: u32 }, - #[error("total_chunks mismatch (expected {expected}, got {actual})")] - TotalChunksMismatch { expected: u32, actual: u32 }, -} - -/// Wraps a serialized payload as a complete (non-chunked) streaming message. -pub fn wrap_complete(data: Vec) -> Vec { - let mut buf = Vec::with_capacity(1 + data.len()); - buf.push(MSG_COMPLETE); - buf.extend_from_slice(&data); - buf -} - -/// Splits a serialized payload into chunked streaming messages. 
-pub fn chunk_payload(data: &[u8]) -> Vec> { - if data.is_empty() { - let mut buf = Vec::with_capacity(CHUNK_HEADER_SIZE); - buf.push(MSG_CHUNK); - buf.extend_from_slice(&1u32.to_le_bytes()); - return vec![buf]; - } - - let total_chunks = data.len().div_ceil(DEFAULT_CHUNK_SIZE); - let mut chunks = Vec::with_capacity(total_chunks); - - for chunk_data in data.chunks(DEFAULT_CHUNK_SIZE) { - let mut buf = Vec::with_capacity(CHUNK_HEADER_SIZE + chunk_data.len()); - buf.push(MSG_CHUNK); - buf.extend_from_slice(&(total_chunks as u32).to_le_bytes()); - buf.extend_from_slice(chunk_data); - chunks.push(buf); - } - - chunks -} - -/// Parses a raw WebSocket binary message into a streaming protocol message. -pub fn parse_message(data: &[u8]) -> Result, StreamError> { - if data.is_empty() { - return Err(StreamError::MessageTooShort { - expected: 1, - actual: 0, - }); - } - - match data[0] { - MSG_COMPLETE => Ok(StreamMessage::Complete(&data[1..])), - MSG_CHUNK => { - if data.len() < CHUNK_HEADER_SIZE { - return Err(StreamError::MessageTooShort { - expected: CHUNK_HEADER_SIZE, - actual: data.len(), - }); - } - let total_chunks = u32::from_le_bytes([data[1], data[2], data[3], data[4]]); - - if total_chunks == 0 { - return Err(StreamError::ZeroTotalChunks); - } - if total_chunks > MAX_TOTAL_CHUNKS { - return Err(StreamError::TotalChunksTooLarge { - total_chunks, - max: MAX_TOTAL_CHUNKS, - }); - } - - Ok(StreamMessage::Chunk { - total_chunks, - payload: &data[CHUNK_HEADER_SIZE..], - }) - } - other => Err(StreamError::UnknownMessageType(other)), - } -} - -/// Sequential reassembly buffer for chunked streams. -/// -/// TCP guarantees ordered delivery and the select loop serializes message sends, -/// so chunks always arrive in order. This buffer simply appends incoming chunks -/// and returns the complete payload when all arrive. 
-pub struct ChunkReassemblyBuffer { - data: Vec, - total_chunks: u32, - received: u32, -} - -impl ChunkReassemblyBuffer { - pub fn new() -> Self { - Self { - data: Vec::new(), - total_chunks: 0, - received: 0, - } - } - - /// Receives a chunk and returns the fully reassembled payload when all chunks arrive. - /// - /// Returns `Ok(None)` if more chunks are needed. - pub fn receive_chunk( - &mut self, - total_chunks: u32, - payload: &[u8], - ) -> Result>, StreamError> { - if self.received == 0 { - self.total_chunks = total_chunks; - self.data - .reserve(total_chunks as usize * DEFAULT_CHUNK_SIZE); - } else if self.total_chunks != total_chunks { - return Err(StreamError::TotalChunksMismatch { - expected: self.total_chunks, - actual: total_chunks, - }); - } - - self.data.extend_from_slice(payload); - self.received += 1; - - if self.received == self.total_chunks { - let result = std::mem::take(&mut self.data); - self.received = 0; - self.total_chunks = 0; - Ok(Some(result)) - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn wrap_complete_roundtrip() { - let data = vec![1, 2, 3, 4, 5]; - let wrapped = wrap_complete(data.clone()); - assert_eq!(wrapped[0], MSG_COMPLETE); - match parse_message(&wrapped).unwrap() { - StreamMessage::Complete(payload) => assert_eq!(payload, &data[..]), - StreamMessage::Chunk { .. 
} => panic!("expected Complete"), - } - } - - #[test] - fn chunk_small_payload_roundtrip() { - let data = vec![42u8; 1024]; - let chunks = chunk_payload(&data); - assert_eq!(chunks.len(), 1); - - match parse_message(&chunks[0]).unwrap() { - StreamMessage::Chunk { - total_chunks, - payload, - } => { - assert_eq!(total_chunks, 1); - assert_eq!(payload, &data[..]); - } - StreamMessage::Complete(_) => panic!("expected Chunk"), - } - } - - #[test] - fn chunk_large_payload_roundtrip() { - let data: Vec = (0..600 * 1024).map(|i| (i % 256) as u8).collect(); - let chunks = chunk_payload(&data); - assert_eq!(chunks.len(), 3); - - let mut reassembly = ChunkReassemblyBuffer::new(); - for (i, chunk) in chunks.iter().enumerate() { - match parse_message(chunk).unwrap() { - StreamMessage::Chunk { - total_chunks, - payload, - } => { - let result = reassembly.receive_chunk(total_chunks, payload).unwrap(); - if i < 2 { - assert!(result.is_none()); - } else { - assert_eq!(result.unwrap(), data); - } - } - StreamMessage::Complete(_) => panic!("expected Chunk"), - } - } - } - - #[test] - fn chunk_empty_payload() { - let chunks = chunk_payload(&[]); - assert_eq!(chunks.len(), 1); - - match parse_message(&chunks[0]).unwrap() { - StreamMessage::Chunk { - total_chunks, - payload, - } => { - assert_eq!(total_chunks, 1); - assert!(payload.is_empty()); - - let mut reassembly = ChunkReassemblyBuffer::new(); - let result = reassembly.receive_chunk(total_chunks, payload).unwrap(); - assert_eq!(result.unwrap(), Vec::::new()); - } - StreamMessage::Complete(_) => panic!("expected Chunk"), - } - } - - #[test] - fn parse_errors() { - assert!(matches!( - parse_message(&[]).unwrap_err(), - StreamError::MessageTooShort { .. } - )); - assert!(matches!( - parse_message(&[0xFF, 1, 2, 3]).unwrap_err(), - StreamError::UnknownMessageType(0xFF) - )); - assert!(matches!( - parse_message(&[MSG_CHUNK, 0, 0]).unwrap_err(), - StreamError::MessageTooShort { .. 
} - )); - - let mut zero_chunks = vec![MSG_CHUNK]; - zero_chunks.extend_from_slice(&0u32.to_le_bytes()); - assert!(matches!( - parse_message(&zero_chunks).unwrap_err(), - StreamError::ZeroTotalChunks - )); - - let mut too_large = vec![MSG_CHUNK]; - too_large.extend_from_slice(&1000u32.to_le_bytes()); - assert!(matches!( - parse_message(&too_large).unwrap_err(), - StreamError::TotalChunksTooLarge { .. } - )); - } - - #[test] - fn total_chunks_mismatch() { - let mut reassembly = ChunkReassemblyBuffer::new(); - reassembly.receive_chunk(3, &[1, 2, 3]).unwrap(); - assert!(matches!( - reassembly.receive_chunk(5, &[4, 5, 6]).unwrap_err(), - StreamError::TotalChunksMismatch { .. } - )); - } - - #[test] - fn reassembly_resets_after_completion() { - let data_a = vec![0xAA; DEFAULT_CHUNK_SIZE * 2]; - let data_b = vec![0xBB; DEFAULT_CHUNK_SIZE * 3]; - - let mut reassembly = ChunkReassemblyBuffer::new(); - - for chunk in &chunk_payload(&data_a) { - if let StreamMessage::Chunk { - total_chunks, - payload, - } = parse_message(chunk).unwrap() - { - if let Some(r) = reassembly.receive_chunk(total_chunks, payload).unwrap() { - assert_eq!(r, data_a); - } - } - } - - for chunk in &chunk_payload(&data_b) { - if let StreamMessage::Chunk { - total_chunks, - payload, - } = parse_message(chunk).unwrap() - { - if let Some(r) = reassembly.receive_chunk(total_chunks, payload).unwrap() { - assert_eq!(r, data_b); - } - } - } - } -} diff --git a/rust/src/generated/client_request_generated.rs b/rust/src/generated/client_request_generated.rs index 2540ed1..e25782b 100644 --- a/rust/src/generated/client_request_generated.rs +++ b/rust/src/generated/client_request_generated.rs @@ -435,18 +435,19 @@ pub mod client_request { since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] - pub const ENUM_MAX_CLIENT_REQUEST_TYPE: u8 = 4; + pub const ENUM_MAX_CLIENT_REQUEST_TYPE: u8 = 5; #[deprecated( since = "2.0.0", note = "Use associated constants instead. 
This will no longer be generated in 2021." )] #[allow(non_camel_case_types)] - pub const ENUM_VALUES_CLIENT_REQUEST_TYPE: [ClientRequestType; 5] = [ + pub const ENUM_VALUES_CLIENT_REQUEST_TYPE: [ClientRequestType; 6] = [ ClientRequestType::NONE, ClientRequestType::ContractRequest, ClientRequestType::DelegateRequest, ClientRequestType::Disconnect, ClientRequestType::Authenticate, + ClientRequestType::StreamChunk, ]; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] @@ -459,15 +460,17 @@ pub mod client_request { pub const DelegateRequest: Self = Self(2); pub const Disconnect: Self = Self(3); pub const Authenticate: Self = Self(4); + pub const StreamChunk: Self = Self(5); pub const ENUM_MIN: u8 = 0; - pub const ENUM_MAX: u8 = 4; + pub const ENUM_MAX: u8 = 5; pub const ENUM_VALUES: &'static [Self] = &[ Self::NONE, Self::ContractRequest, Self::DelegateRequest, Self::Disconnect, Self::Authenticate, + Self::StreamChunk, ]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { @@ -477,6 +480,7 @@ pub mod client_request { Self::DelegateRequest => Some("DelegateRequest"), Self::Disconnect => Some("Disconnect"), Self::Authenticate => Some("Authenticate"), + Self::StreamChunk => Some("StreamChunk"), _ => None, } } @@ -4172,6 +4176,191 @@ pub mod client_request { ds.finish() } } + pub enum StreamChunkOffset {} + #[derive(Copy, Clone, PartialEq)] + + pub struct StreamChunk<'a> { + pub _tab: flatbuffers::Table<'a>, + } + + impl<'a> flatbuffers::Follow<'a> for StreamChunk<'a> { + type Inner = StreamChunk<'a>; + #[inline] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + Self { + _tab: flatbuffers::Table::new(buf, loc), + } + } + } + + impl<'a> StreamChunk<'a> { + pub const VT_STREAM_ID: flatbuffers::VOffsetT = 4; + pub const VT_INDEX: flatbuffers::VOffsetT = 6; + pub const VT_TOTAL: flatbuffers::VOffsetT = 8; + pub const VT_DATA: flatbuffers::VOffsetT = 10; + + #[inline] + pub unsafe fn 
init_from_table(table: flatbuffers::Table<'a>) -> Self { + StreamChunk { _tab: table } + } + #[allow(unused_mut)] + pub fn create< + 'bldr: 'args, + 'args: 'mut_bldr, + 'mut_bldr, + A: flatbuffers::Allocator + 'bldr, + >( + _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr, A>, + args: &'args StreamChunkArgs<'args>, + ) -> flatbuffers::WIPOffset> { + let mut builder = StreamChunkBuilder::new(_fbb); + if let Some(x) = args.data { + builder.add_data(x); + } + builder.add_total(args.total); + builder.add_index(args.index); + builder.add_stream_id(args.stream_id); + builder.finish() + } + + #[inline] + pub fn stream_id(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_STREAM_ID, Some(0)) + .unwrap() + } + } + #[inline] + pub fn index(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_INDEX, Some(0)) + .unwrap() + } + } + #[inline] + pub fn total(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_TOTAL, Some(0)) + .unwrap() + } + } + #[inline] + pub fn data(&self) -> flatbuffers::Vector<'a, u8> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + StreamChunk::VT_DATA, + None, + ) + .unwrap() + } + } + } + + impl flatbuffers::Verifiable for StreamChunk<'_> { + #[inline] + fn run_verifier( + v: &mut flatbuffers::Verifier, + pos: usize, + ) -> Result<(), flatbuffers::InvalidFlatbuffer> { + use self::flatbuffers::Verifiable; + v.visit_table(pos)? + .visit_field::("stream_id", Self::VT_STREAM_ID, false)? + .visit_field::("index", Self::VT_INDEX, false)? + .visit_field::("total", Self::VT_TOTAL, false)? 
+ .visit_field::>>( + "data", + Self::VT_DATA, + true, + )? + .finish(); + Ok(()) + } + } + pub struct StreamChunkArgs<'a> { + pub stream_id: u32, + pub index: u32, + pub total: u32, + pub data: Option>>, + } + impl<'a> Default for StreamChunkArgs<'a> { + #[inline] + fn default() -> Self { + StreamChunkArgs { + stream_id: 0, + index: 0, + total: 0, + data: None, // required field + } + } + } + + pub struct StreamChunkBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { + fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + start_: flatbuffers::WIPOffset, + } + impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> StreamChunkBuilder<'a, 'b, A> { + #[inline] + pub fn add_stream_id(&mut self, stream_id: u32) { + self.fbb_ + .push_slot::(StreamChunk::VT_STREAM_ID, stream_id, 0); + } + #[inline] + pub fn add_index(&mut self, index: u32) { + self.fbb_.push_slot::(StreamChunk::VT_INDEX, index, 0); + } + #[inline] + pub fn add_total(&mut self, total: u32) { + self.fbb_.push_slot::(StreamChunk::VT_TOTAL, total, 0); + } + #[inline] + pub fn add_data(&mut self, data: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(StreamChunk::VT_DATA, data); + } + #[inline] + pub fn new( + _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + ) -> StreamChunkBuilder<'a, 'b, A> { + let start = _fbb.start_table(); + StreamChunkBuilder { + fbb_: _fbb, + start_: start, + } + } + #[inline] + pub fn finish(self) -> flatbuffers::WIPOffset> { + let o = self.fbb_.end_table(self.start_); + self.fbb_.required(o, StreamChunk::VT_DATA, "data"); + flatbuffers::WIPOffset::new(o.value()) + } + } + + impl core::fmt::Debug for StreamChunk<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut ds = f.debug_struct("StreamChunk"); + ds.field("stream_id", &self.stream_id()); + ds.field("index", &self.index()); + ds.field("total", &self.total()); + ds.field("data", &self.data()); + ds.finish() + } + } pub enum ClientRequestOffset {} #[derive(Copy, Clone, PartialEq)] 
@@ -4298,6 +4487,20 @@ pub mod client_request { None } } + + #[inline] + #[allow(non_snake_case)] + pub fn client_request_as_stream_chunk(&self) -> Option> { + if self.client_request_type() == ClientRequestType::StreamChunk { + let u = self.client_request(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid union in this slot + Some(unsafe { StreamChunk::init_from_table(u) }) + } else { + None + } + } } impl flatbuffers::Verifiable for ClientRequest<'_> { @@ -4335,6 +4538,11 @@ pub mod client_request { "ClientRequestType::Authenticate", pos, ), + ClientRequestType::StreamChunk => v + .verify_union_variant::>( + "ClientRequestType::StreamChunk", + pos, + ), _ => Ok(()), }, )? @@ -4443,6 +4651,16 @@ pub mod client_request { ) } } + ClientRequestType::StreamChunk => { + if let Some(x) = self.client_request_as_stream_chunk() { + ds.field("client_request", &x) + } else { + ds.field( + "client_request", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("client_request", &x) diff --git a/rust/src/generated/host_response_generated.rs b/rust/src/generated/host_response_generated.rs index 3077c50..91ce18d 100644 --- a/rust/src/generated/host_response_generated.rs +++ b/rust/src/generated/host_response_generated.rs @@ -246,19 +246,20 @@ pub mod host_response { since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] - pub const ENUM_MAX_HOST_RESPONSE_TYPE: u8 = 5; + pub const ENUM_MAX_HOST_RESPONSE_TYPE: u8 = 6; #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." 
)] #[allow(non_camel_case_types)] - pub const ENUM_VALUES_HOST_RESPONSE_TYPE: [HostResponseType; 6] = [ + pub const ENUM_VALUES_HOST_RESPONSE_TYPE: [HostResponseType; 7] = [ HostResponseType::NONE, HostResponseType::ContractResponse, HostResponseType::DelegateResponse, HostResponseType::GenerateRandData, HostResponseType::Ok, HostResponseType::Error, + HostResponseType::StreamChunk, ]; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] @@ -272,9 +273,10 @@ pub mod host_response { pub const GenerateRandData: Self = Self(3); pub const Ok: Self = Self(4); pub const Error: Self = Self(5); + pub const StreamChunk: Self = Self(6); pub const ENUM_MIN: u8 = 0; - pub const ENUM_MAX: u8 = 5; + pub const ENUM_MAX: u8 = 6; pub const ENUM_VALUES: &'static [Self] = &[ Self::NONE, Self::ContractResponse, @@ -282,6 +284,7 @@ pub mod host_response { Self::GenerateRandData, Self::Ok, Self::Error, + Self::StreamChunk, ]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { @@ -292,6 +295,7 @@ pub mod host_response { Self::GenerateRandData => Some("GenerateRandData"), Self::Ok => Some("Ok"), Self::Error => Some("Error"), + Self::StreamChunk => Some("StreamChunk"), _ => None, } } @@ -2934,6 +2938,191 @@ pub mod host_response { ds.finish() } } + pub enum StreamChunkOffset {} + #[derive(Copy, Clone, PartialEq)] + + pub struct StreamChunk<'a> { + pub _tab: flatbuffers::Table<'a>, + } + + impl<'a> flatbuffers::Follow<'a> for StreamChunk<'a> { + type Inner = StreamChunk<'a>; + #[inline] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + Self { + _tab: flatbuffers::Table::new(buf, loc), + } + } + } + + impl<'a> StreamChunk<'a> { + pub const VT_STREAM_ID: flatbuffers::VOffsetT = 4; + pub const VT_INDEX: flatbuffers::VOffsetT = 6; + pub const VT_TOTAL: flatbuffers::VOffsetT = 8; + pub const VT_DATA: flatbuffers::VOffsetT = 10; + + #[inline] + pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { + 
StreamChunk { _tab: table } + } + #[allow(unused_mut)] + pub fn create< + 'bldr: 'args, + 'args: 'mut_bldr, + 'mut_bldr, + A: flatbuffers::Allocator + 'bldr, + >( + _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr, A>, + args: &'args StreamChunkArgs<'args>, + ) -> flatbuffers::WIPOffset> { + let mut builder = StreamChunkBuilder::new(_fbb); + if let Some(x) = args.data { + builder.add_data(x); + } + builder.add_total(args.total); + builder.add_index(args.index); + builder.add_stream_id(args.stream_id); + builder.finish() + } + + #[inline] + pub fn stream_id(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_STREAM_ID, Some(0)) + .unwrap() + } + } + #[inline] + pub fn index(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_INDEX, Some(0)) + .unwrap() + } + } + #[inline] + pub fn total(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_TOTAL, Some(0)) + .unwrap() + } + } + #[inline] + pub fn data(&self) -> flatbuffers::Vector<'a, u8> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + StreamChunk::VT_DATA, + None, + ) + .unwrap() + } + } + } + + impl flatbuffers::Verifiable for StreamChunk<'_> { + #[inline] + fn run_verifier( + v: &mut flatbuffers::Verifier, + pos: usize, + ) -> Result<(), flatbuffers::InvalidFlatbuffer> { + use self::flatbuffers::Verifiable; + v.visit_table(pos)? + .visit_field::("stream_id", Self::VT_STREAM_ID, false)? + .visit_field::("index", Self::VT_INDEX, false)? + .visit_field::("total", Self::VT_TOTAL, false)? + .visit_field::>>( + "data", + Self::VT_DATA, + true, + )? 
+ .finish(); + Ok(()) + } + } + pub struct StreamChunkArgs<'a> { + pub stream_id: u32, + pub index: u32, + pub total: u32, + pub data: Option>>, + } + impl<'a> Default for StreamChunkArgs<'a> { + #[inline] + fn default() -> Self { + StreamChunkArgs { + stream_id: 0, + index: 0, + total: 0, + data: None, // required field + } + } + } + + pub struct StreamChunkBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { + fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + start_: flatbuffers::WIPOffset, + } + impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> StreamChunkBuilder<'a, 'b, A> { + #[inline] + pub fn add_stream_id(&mut self, stream_id: u32) { + self.fbb_ + .push_slot::(StreamChunk::VT_STREAM_ID, stream_id, 0); + } + #[inline] + pub fn add_index(&mut self, index: u32) { + self.fbb_.push_slot::(StreamChunk::VT_INDEX, index, 0); + } + #[inline] + pub fn add_total(&mut self, total: u32) { + self.fbb_.push_slot::(StreamChunk::VT_TOTAL, total, 0); + } + #[inline] + pub fn add_data(&mut self, data: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(StreamChunk::VT_DATA, data); + } + #[inline] + pub fn new( + _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + ) -> StreamChunkBuilder<'a, 'b, A> { + let start = _fbb.start_table(); + StreamChunkBuilder { + fbb_: _fbb, + start_: start, + } + } + #[inline] + pub fn finish(self) -> flatbuffers::WIPOffset> { + let o = self.fbb_.end_table(self.start_); + self.fbb_.required(o, StreamChunk::VT_DATA, "data"); + flatbuffers::WIPOffset::new(o.value()) + } + } + + impl core::fmt::Debug for StreamChunk<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut ds = f.debug_struct("StreamChunk"); + ds.field("stream_id", &self.stream_id()); + ds.field("index", &self.index()); + ds.field("total", &self.total()); + ds.field("data", &self.data()); + ds.finish() + } + } pub enum HostResponseOffset {} #[derive(Copy, Clone, PartialEq)] @@ -3074,6 +3263,20 @@ pub mod host_response { None } } + + 
#[inline] + #[allow(non_snake_case)] + pub fn response_as_stream_chunk(&self) -> Option> { + if self.response_type() == HostResponseType::StreamChunk { + let u = self.response(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid union in this slot + Some(unsafe { StreamChunk::init_from_table(u) }) + } else { + None + } + } } impl flatbuffers::Verifiable for HostResponse<'_> { @@ -3116,6 +3319,11 @@ pub mod host_response { "HostResponseType::Error", pos, ), + HostResponseType::StreamChunk => v + .verify_union_variant::>( + "HostResponseType::StreamChunk", + pos, + ), _ => Ok(()), }, )? @@ -3231,6 +3439,16 @@ pub mod host_response { ) } } + HostResponseType::StreamChunk => { + if let Some(x) = self.response_as_stream_chunk() { + ds.field("response", &x) + } else { + ds.field( + "response", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("response", &x) diff --git a/schemas/flatbuffers/client_request.fbs b/schemas/flatbuffers/client_request.fbs index 10112f1..1bc4393 100644 --- a/schemas/flatbuffers/client_request.fbs +++ b/schemas/flatbuffers/client_request.fbs @@ -128,11 +128,19 @@ table Authenticate { token:string(required); } +table StreamChunk { + stream_id:uint32; + index:uint32; + total:uint32; + data:[ubyte](required); +} + union ClientRequestType { ContractRequest, DelegateRequest, Disconnect, - Authenticate + Authenticate, + StreamChunk } table ClientRequest { diff --git a/schemas/flatbuffers/host_response.fbs b/schemas/flatbuffers/host_response.fbs index f11acf7..8622eb6 100644 --- a/schemas/flatbuffers/host_response.fbs +++ b/schemas/flatbuffers/host_response.fbs @@ -85,6 +85,13 @@ table GenerateRandData { wrapped_state: [ubyte](required); } +table StreamChunk { + stream_id:uint32; + index:uint32; + total:uint32; + data:[ubyte](required); +} + table Ok { msg:string(required); } @@ -97,8 +104,9 @@ union HostResponseType { ContractResponse, 
DelegateResponse, GenerateRandData, Ok, - Error + Error, + StreamChunk } table HostResponse { From fee111bd03ccca733f6f20ed01354c9dbe7fe1ad Mon Sep 17 00:00:00 2001 From: Hector Santos Date: Sun, 1 Mar 2026 19:05:12 +0100 Subject: [PATCH 3/5] refactor WebSocket chunk handling to use Bytes for zero-copy efficiency --- rust/Cargo.toml | 1 + rust/src/client_api/browser.rs | 2 +- rust/src/client_api/client_events.rs | 7 +- rust/src/client_api/regular.rs | 2 +- rust/src/client_api/streaming.rs | 123 ++++++++++++++++++--------- 5 files changed, 88 insertions(+), 47 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 20d4e70..0f6081f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/freenet/freenet-stdlib" [dependencies] arbitrary = { version = "1", optional = true, features = ["derive"] } bincode = "1" +bytes = { version = "1", features = ["serde"] } byteorder = "1" blake3 = { version = "1", features = ["std", "traits-preview"] } bs58 = "0.5" diff --git a/rust/src/client_api/browser.rs b/rust/src/client_api/browser.rs index 5afd7ff..320ac67 100644 --- a/rust/src/client_api/browser.rs +++ b/rust/src/client_api/browser.rs @@ -84,7 +84,7 @@ impl WebApi { }) => { match reassembly_clone .borrow_mut() - .receive_chunk(stream_id, index, total, &data) + .receive_chunk(stream_id, index, total, data) { Ok(Some(complete)) => { let inner: HostResult = match bincode::deserialize(&complete) { diff --git a/rust/src/client_api/client_events.rs b/rust/src/client_api/client_events.rs index 30481fb..c54793e 100644 --- a/rust/src/client_api/client_events.rs +++ b/rust/src/client_api/client_events.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use flatbuffers::WIPOffset; use std::borrow::Cow; use std::fmt::Display; @@ -263,7 +264,7 @@ pub enum ClientRequest<'a> { stream_id: u32, index: u32, total: u32, - data: Vec, + data: Bytes, }, } @@ -380,7 +381,7 @@ impl ClientRequest<'_> { stream_id: chunk.stream_id(), index: chunk.index(),
total: chunk.total(), - data: chunk.data().bytes().to_vec(), + data: Bytes::from(chunk.data().bytes().to_vec()), } } _ => { @@ -747,7 +748,7 @@ pub enum HostResponse { stream_id: u32, index: u32, total: u32, - data: Vec, + data: Bytes, }, /// Header message announcing the start of a streamed response. /// Sent before the corresponding [`StreamChunk`] messages so the client diff --git a/rust/src/client_api/regular.rs b/rust/src/client_api/regular.rs index bbaf10f..88d7ea4 100644 --- a/rust/src/client_api/regular.rs +++ b/rust/src/client_api/regular.rs @@ -348,7 +348,7 @@ async fn handle_response_payload( } else { // No StreamHeader seen → transparent reassembly (backward compat) match reassembly - .receive_chunk(stream_id, index, total, &data) + .receive_chunk(stream_id, index, total, data) .map_err(|e| Error::OtherError(e.into()))? { Some(complete) => { diff --git a/rust/src/client_api/streaming.rs b/rust/src/client_api/streaming.rs index 013c875..554d021 100644 --- a/rust/src/client_api/streaming.rs +++ b/rust/src/client_api/streaming.rs @@ -6,6 +6,8 @@ use std::collections::HashMap; +use bytes::Bytes; + use super::{ClientRequest, HostResponse}; /// Default chunk payload size: 256 KiB. @@ -25,40 +27,51 @@ pub const MAX_CONCURRENT_STREAMS: usize = 8; /// Chunks to send before yielding to the tokio runtime. pub const MAX_CHUNKS_PER_BATCH: usize = 32; -/// Compute chunking metadata: returns (total_chunks, chunk_iterator). -fn chunk_data(data: &[u8]) -> impl Iterator { +/// Zero-copy chunking: split `data` into (index, total, slice) tuples using `Bytes::slice()`. 
+fn chunk_bytes(data: &Bytes) -> Vec<(u32, u32, Bytes)> { let total = data.len().div_ceil(CHUNK_SIZE).max(1) as u32; - data.chunks(CHUNK_SIZE) - .chain(if data.is_empty() { - // Yield a single empty slice for empty payloads - Some([].as_slice()).into_iter() - } else { - None.into_iter() + if data.is_empty() { + return vec![(0, 1, Bytes::new())]; + } + (0..total as usize) + .map(|i| { + let start = i * CHUNK_SIZE; + let end = (start + CHUNK_SIZE).min(data.len()); + (i as u32, total, data.slice(start..end)) }) - .enumerate() - .map(move |(i, chunk)| (i as u32, total, chunk)) + .collect() } /// Split a serialized request payload into `StreamChunk` client request variants. +/// +/// Uses `Bytes::slice()` internally for zero-copy: each chunk shares the +/// original allocation via reference counting instead of copying. pub fn chunk_request(data: Vec, stream_id: u32) -> Vec> { - chunk_data(&data) + let data = Bytes::from(data); + chunk_bytes(&data) + .into_iter() .map(|(index, total, chunk)| ClientRequest::StreamChunk { stream_id, index, total, - data: chunk.to_vec(), + data: chunk, }) .collect() } /// Split a serialized response payload into `StreamChunk` host response variants. +/// +/// Uses `Bytes::slice()` internally for zero-copy: each chunk shares the +/// original allocation via reference counting instead of copying. 
pub fn chunk_response(data: Vec, stream_id: u32) -> Vec { - chunk_data(&data) + let data = Bytes::from(data); + chunk_bytes(&data) + .into_iter() .map(|(index, total, chunk)| HostResponse::StreamChunk { stream_id, index, total, - data: chunk.to_vec(), + data: chunk, }) .collect() } @@ -92,7 +105,7 @@ pub enum StreamError { } struct StreamState { - chunks: Vec>>, + chunks: Vec>, total: u32, received: u32, } @@ -116,7 +129,7 @@ impl ReassemblyBuffer { stream_id: u32, index: u32, total: u32, - data: &[u8], + data: Bytes, ) -> Result>, StreamError> { if total == 0 { return Err(StreamError::ZeroTotalChunks); @@ -165,7 +178,7 @@ impl ReassemblyBuffer { return Err(StreamError::DuplicateChunk { stream_id, index }); } - state.chunks[idx] = Some(data.to_vec()); + state.chunks[idx] = Some(data); state.received += 1; if state.received == state.total { @@ -197,6 +210,7 @@ pub use app_stream::*; mod app_stream { use super::StreamError; use crate::client_api::client_events::StreamContent; + use bytes::Bytes; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::sync::mpsc; @@ -207,12 +221,12 @@ mod app_stream { /// The corresponding [`WsStreamSender`] feeds chunks into this handle as they arrive. /// /// Two consumption modes: - /// - [`into_stream()`](WsStreamHandle::into_stream) — incremental async `Stream>` + /// - [`into_stream()`](WsStreamHandle::into_stream) — incremental async `Stream` /// - [`assemble()`](WsStreamHandle::assemble) — blocking wait for the complete payload pub struct WsStreamHandle { content: StreamContent, total_bytes: u64, - chunk_rx: mpsc::UnboundedReceiver>, + chunk_rx: mpsc::UnboundedReceiver, } impl WsStreamHandle { @@ -256,11 +270,11 @@ mod app_stream { /// Async stream of chunk data produced by [`WsStreamHandle::into_stream()`]. 
pub struct WsStream { - chunk_rx: mpsc::UnboundedReceiver>, + chunk_rx: mpsc::UnboundedReceiver, } impl futures::Stream for WsStream { - type Item = Vec; + type Item = Bytes; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.chunk_rx.poll_recv(cx) @@ -269,12 +283,12 @@ mod app_stream { /// Sender side — held by the request handler to feed chunks into the handle. pub struct WsStreamSender { - chunk_tx: mpsc::UnboundedSender>, + chunk_tx: mpsc::UnboundedSender, } impl WsStreamSender { /// Send a chunk of data to the corresponding [`WsStreamHandle`]. - pub fn send_chunk(&self, data: Vec) -> Result<(), StreamError> { + pub fn send_chunk(&self, data: Bytes) -> Result<(), StreamError> { self.chunk_tx .send(data) .map_err(|_| StreamError::ChannelClosed) @@ -339,7 +353,7 @@ mod tests { } = chunk { if let Some(result) = buf - .receive_chunk(*stream_id, *index, *total, chunk_data) + .receive_chunk(*stream_id, *index, *total, chunk_data.clone()) .unwrap() { assert_eq!(result, data); @@ -364,7 +378,7 @@ mod tests { } = chunk { if let Some(result) = buf - .receive_chunk(*stream_id, *index, *total, chunk_data) + .receive_chunk(*stream_id, *index, *total, chunk_data.clone()) .unwrap() { assert_eq!(result, data); @@ -401,7 +415,10 @@ mod tests { data, } = chunk { - if let Some(r) = buf.receive_chunk(*stream_id, *index, *total, data).unwrap() { + if let Some(r) = buf + .receive_chunk(*stream_id, *index, *total, data.clone()) + .unwrap() + { assert_eq!(r, data_a); } } @@ -415,7 +432,10 @@ mod tests { data, } = chunk { - if let Some(r) = buf.receive_chunk(*stream_id, *index, *total, data).unwrap() { + if let Some(r) = buf + .receive_chunk(*stream_id, *index, *total, data.clone()) + .unwrap() + { assert_eq!(r, data_b); } } @@ -425,30 +445,40 @@ mod tests { #[test] fn zero_total_chunks_error() { let mut buf = ReassemblyBuffer::new(); - let err = buf.receive_chunk(1, 0, 0, &[1, 2, 3]).unwrap_err(); + let err = buf + .receive_chunk(1, 0, 0, 
Bytes::from_static(&[1, 2, 3])) + .unwrap_err(); assert!(matches!(err, StreamError::ZeroTotalChunks)); } #[test] fn total_too_large_error() { let mut buf = ReassemblyBuffer::new(); - let err = buf.receive_chunk(1, 0, 1000, &[1, 2, 3]).unwrap_err(); + let err = buf + .receive_chunk(1, 0, 1000, Bytes::from_static(&[1, 2, 3])) + .unwrap_err(); assert!(matches!(err, StreamError::TotalChunksTooLarge { .. })); } #[test] fn total_mismatch_error() { let mut buf = ReassemblyBuffer::new(); - buf.receive_chunk(1, 0, 3, &[1, 2, 3]).unwrap(); - let err = buf.receive_chunk(1, 1, 5, &[4, 5, 6]).unwrap_err(); + buf.receive_chunk(1, 0, 3, Bytes::from_static(&[1, 2, 3])) + .unwrap(); + let err = buf + .receive_chunk(1, 1, 5, Bytes::from_static(&[4, 5, 6])) + .unwrap_err(); assert!(matches!(err, StreamError::TotalChunksMismatch { .. })); } #[test] fn duplicate_chunk_error() { let mut buf = ReassemblyBuffer::new(); - buf.receive_chunk(1, 0, 3, &[1, 2, 3]).unwrap(); - let err = buf.receive_chunk(1, 0, 3, &[4, 5, 6]).unwrap_err(); + buf.receive_chunk(1, 0, 3, Bytes::from_static(&[1, 2, 3])) + .unwrap(); + let err = buf + .receive_chunk(1, 0, 3, Bytes::from_static(&[4, 5, 6])) + .unwrap_err(); assert!(matches!( err, StreamError::DuplicateChunk { @@ -461,7 +491,9 @@ mod tests { #[test] fn index_out_of_range_error() { let mut buf = ReassemblyBuffer::new(); - let err = buf.receive_chunk(1, 5, 3, &[1, 2, 3]).unwrap_err(); + let err = buf + .receive_chunk(1, 5, 3, Bytes::from_static(&[1, 2, 3])) + .unwrap_err(); assert!(matches!( err, StreamError::IndexOutOfRange { @@ -476,10 +508,16 @@ mod tests { fn too_many_concurrent_streams_error() { let mut buf = ReassemblyBuffer::new(); for i in 0..MAX_CONCURRENT_STREAMS as u32 { - buf.receive_chunk(i, 0, 2, &[1]).unwrap(); + buf.receive_chunk(i, 0, 2, Bytes::from_static(&[1])) + .unwrap(); } let err = buf - .receive_chunk(MAX_CONCURRENT_STREAMS as u32, 0, 2, &[1]) + .receive_chunk( + MAX_CONCURRENT_STREAMS as u32, + 0, + 2, + Bytes::from_static(&[1]), 
+ ) .unwrap_err(); assert!(matches!(err, StreamError::TooManyConcurrentStreams { .. })); } @@ -500,31 +538,32 @@ mod tests { key, includes_contract: false, }; - let data = vec![0xAB; CHUNK_SIZE * 3]; + let data = Bytes::from(vec![0xAB; CHUNK_SIZE * 3]); let (handle, sender) = ws_stream_pair(content, data.len() as u64); // Feed chunks in a background task let data_clone = data.clone(); tokio::spawn(async move { for chunk in data_clone.chunks(CHUNK_SIZE) { - sender.send_chunk(chunk.to_vec()).unwrap(); + sender.send_chunk(Bytes::copy_from_slice(chunk)).unwrap(); } // sender dropped here → handle's rx will close }); let result = handle.assemble().await.unwrap(); - assert_eq!(result, data); + assert_eq!(result, &data[..]); } #[tokio::test] async fn ws_stream_incremental() { let content = StreamContent::Raw; - let data = vec![0xCD; CHUNK_SIZE * 2]; + let data = Bytes::from(vec![0xCD; CHUNK_SIZE * 2]); let (handle, sender) = ws_stream_pair(content, data.len() as u64); + let data_clone = data.clone(); tokio::spawn(async move { - for chunk in data.chunks(CHUNK_SIZE) { - sender.send_chunk(chunk.to_vec()).unwrap(); + for chunk in data_clone.chunks(CHUNK_SIZE) { + sender.send_chunk(Bytes::copy_from_slice(chunk)).unwrap(); } }); From 9aa703dccf377306e3c02819987d953adb00e7e9 Mon Sep 17 00:00:00 2001 From: Hector Santos Date: Thu, 5 Mar 2026 08:41:30 +0100 Subject: [PATCH 4/5] streaming improvements --- rust/src/client_api/browser.rs | 5 +- rust/src/client_api/regular.rs | 137 +++++--- rust/src/client_api/streaming.rs | 133 +++++++- rust/src/generated/host_response_generated.rs | 320 +++++++++--------- schemas/flatbuffers/host_response.fbs | 2 +- 5 files changed, 388 insertions(+), 209 deletions(-) diff --git a/rust/src/client_api/browser.rs b/rust/src/client_api/browser.rs index 320ac67..6579e6a 100644 --- a/rust/src/client_api/browser.rs +++ b/rust/src/client_api/browser.rs @@ -11,6 +11,7 @@ type Connection = web_sys::WebSocket; pub struct WebApi { conn: Connection, 
error_handler: Box, + next_stream_id: u32, } impl Drop for WebApi { @@ -163,6 +164,7 @@ impl WebApi { WebApi { conn, error_handler: Box::new(error_handler), + next_stream_id: 0, } } @@ -192,7 +194,8 @@ impl WebApi { let send = bincode::serialize(&request)?; if send.len() > CHUNK_THRESHOLD { - let stream_id = 0; // browser client uses single stream + let stream_id = self.next_stream_id; + self.next_stream_id = self.next_stream_id.wrapping_add(1); let chunks = chunk_request(send, stream_id); for chunk in &chunks { let chunk_bytes = diff --git a/rust/src/client_api/regular.rs b/rust/src/client_api/regular.rs index 88d7ea4..b848f2c 100644 --- a/rust/src/client_api/regular.rs +++ b/rust/src/client_api/regular.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, collections::HashMap, task::Poll}; +use std::{borrow::Cow, collections::HashMap, collections::VecDeque, future::Future, task::Poll}; use super::{ client_events::{ClientError, ClientRequest, ErrorKind, HostResponse}, @@ -25,6 +25,7 @@ pub struct WebApi { response_rx: Receiver, stream_rx: Receiver, queue: Vec>, + pending_streams: VecDeque + Send>>>, } impl Drop for WebApi { @@ -43,7 +44,40 @@ impl Stream for WebApi { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - self.response_rx.poll_recv(cx) + // First, try to complete any pending stream assemblies. + if let Some(fut) = self.pending_streams.front_mut() { + if let Poll::Ready(result) = fut.as_mut().poll(cx) { + self.pending_streams.pop_front(); + return Poll::Ready(Some(result)); + } + } + + // Poll regular responses. + match self.response_rx.poll_recv(cx) { + Poll::Ready(Some(result)) => return Poll::Ready(Some(result)), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => {} + } + + // Poll stream handles and spawn assembly as a pending future. 
+ match self.stream_rx.poll_recv(cx) { + Poll::Ready(Some(handle)) => { + let fut = Box::pin(async move { + let complete = handle + .assemble() + .await + .map_err(|e| ClientError::from(format!("{e}")))?; + let inner: HostResult = bincode::deserialize(&complete) + .map_err(|e| ClientError::from(format!("{e}")))?; + inner + }); + self.pending_streams.push_back(fut); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(None) if self.pending_streams.is_empty() => Poll::Ready(None), + _ => Poll::Pending, + } } } @@ -108,7 +142,7 @@ impl WebApi { pub fn start(connection: Connection) -> Self { let (request_tx, request_rx) = mpsc::channel(1); let (response_tx, response_rx) = mpsc::channel(1); - let (stream_tx, stream_rx) = mpsc::channel(4); + let (stream_tx, stream_rx) = mpsc::channel(super::streaming::MAX_CONCURRENT_STREAMS); tokio::spawn(request_handler( request_rx, response_tx, @@ -120,6 +154,7 @@ impl WebApi { response_rx, stream_rx, queue: vec![], + pending_streams: VecDeque::new(), } } @@ -140,6 +175,13 @@ impl WebApi { /// complete [`HostResponse`] — the caller does not need to handle streaming. /// /// For incremental consumption, use [`recv_stream()`](Self::recv_stream) instead. + /// + /// # Important + /// + /// `recv()` and [`recv_stream()`](Self::recv_stream) both consume from the + /// internal stream channel. Calling both concurrently or alternating between + /// them may cause responses to be delivered to the wrong consumer. Choose + /// one consumption pattern per `WebApi` instance. pub async fn recv(&mut self) -> HostResult { tokio::select! { res = self.response_rx.recv() => { @@ -166,6 +208,11 @@ impl WebApi { /// /// Only returns when the server sends a `StreamHeader`; non-streamed /// responses are delivered through [`recv()`](Self::recv). + /// + /// # Important + /// + /// `recv_stream()` and [`recv()`](Self::recv) both consume from the internal + /// stream channel. See [`recv()`](Self::recv) for details. 
pub async fn recv_stream(&mut self) -> Result { self.stream_rx.recv().await.ok_or(Error::ChannelClosed) } @@ -316,15 +363,24 @@ async fn handle_response_payload( tracing::warn!("too many open stream senders, evicting one"); if let Some(&id) = stream_senders.keys().next() { stream_senders.remove(&id); + reassembly.remove_stream(id); } } let (handle, sender) = super::streaming::ws_stream_pair(content, total_bytes); stream_senders.insert(stream_id, sender); - stream_tx - .send(handle) - .await - .map_err(|_| Error::ChannelClosed)?; - Ok(()) + match stream_tx.try_send(handle) { + Ok(()) => Ok(()), + Err(mpsc::error::TrySendError::Full(_)) => { + tracing::warn!( + stream_id, + "stream_tx full, falling back to transparent reassembly" + ); + // Remove sender so subsequent chunks go through ReassemblyBuffer + stream_senders.remove(&stream_id); + Ok(()) + } + Err(mpsc::error::TrySendError::Closed(_)) => Err(Error::ChannelClosed), + } } Ok(HostResponse::StreamChunk { stream_id, @@ -378,10 +434,17 @@ mod test { use crate::client_api::HostResponse; use super::*; - use std::{net::Ipv4Addr, sync::atomic::AtomicU16, time::Duration}; + use std::{net::Ipv4Addr, time::Duration}; use tokio::net::TcpListener; - static PORT: AtomicU16 = AtomicU16::new(65495); + /// Bind to an OS-assigned port and return the listener + port. 
+ async fn bind_free_port() -> (TcpListener, u16) { + let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0u16)) + .await + .unwrap(); + let port = listener.local_addr().unwrap().port(); + (listener, port) + } struct Server { recv: bool, @@ -389,10 +452,7 @@ mod test { } impl Server { - async fn new(port: u16, recv: bool) -> Self { - let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) - .await - .unwrap(); + async fn new(listener: TcpListener, recv: bool) -> Self { Server { recv, listener } } @@ -401,7 +461,7 @@ mod test { tx: tokio::sync::oneshot::Sender<()>, ) -> Result<(), Box> { let (stream, _) = - tokio::time::timeout(Duration::from_millis(10), self.listener.accept()).await??; + tokio::time::timeout(Duration::from_secs(5), self.listener.accept()).await??; let mut stream = tokio_tungstenite::accept_async(stream).await?; if !self.recv { @@ -453,7 +513,7 @@ mod test { use crate::client_api::streaming; let (tcp_stream, _) = - tokio::time::timeout(Duration::from_millis(100), listener.accept()).await??; + tokio::time::timeout(Duration::from_secs(5), listener.accept()).await??; let mut stream = tokio_tungstenite::accept_async(tcp_stream).await?; let (serialized, key) = build_test_payload(payload_size, fill); @@ -481,7 +541,7 @@ mod test { stream.send(Message::Binary(chunk_bytes.into())).await?; } - let msg = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; + let msg = tokio::time::timeout(Duration::from_secs(5), stream.next()).await; drop(msg); tx.send(()).map_err(|_| "signal failed".to_owned())?; Ok(()) @@ -491,11 +551,8 @@ mod test { async fn test_recv_chunked() -> Result<(), Box> { use crate::client_api::ContractResponse; - let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); let payload_size = 600 * 1024; - let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) - .await - .unwrap(); + let (listener, port) = bind_free_port().await; let (tx, rx) = tokio::sync::oneshot::channel::<()>(); 
let server_result = tokio::task::spawn(serve_chunked_response( listener, @@ -520,8 +577,8 @@ mod test { client .send(ClientRequest::Disconnect { cause: None }) .await?; - tokio::time::timeout(Duration::from_millis(100), rx).await??; - tokio::time::timeout(Duration::from_millis(100), server_result).await???; + tokio::time::timeout(Duration::from_secs(5), rx).await??; + tokio::time::timeout(Duration::from_secs(5), server_result).await???; Ok(()) } @@ -529,11 +586,8 @@ mod test { async fn test_recv_stream_header() -> Result<(), Box> { use crate::client_api::ContractResponse; - let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); let payload_size = 600 * 1024; - let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) - .await - .unwrap(); + let (listener, port) = bind_free_port().await; let (tx, rx) = tokio::sync::oneshot::channel::<()>(); let server_result = tokio::task::spawn(serve_chunked_response( listener, @@ -564,8 +618,8 @@ mod test { client .send(ClientRequest::Disconnect { cause: None }) .await?; - tokio::time::timeout(Duration::from_millis(100), rx).await??; - tokio::time::timeout(Duration::from_millis(100), server_result).await???; + tokio::time::timeout(Duration::from_secs(5), rx).await??; + tokio::time::timeout(Duration::from_secs(5), server_result).await???; Ok(()) } @@ -575,11 +629,8 @@ mod test { ) -> Result<(), Box> { use crate::client_api::ContractResponse; - let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); let payload_size = 600 * 1024; - let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, port)) - .await - .unwrap(); + let (listener, port) = bind_free_port().await; let (tx, rx) = tokio::sync::oneshot::channel::<()>(); let server_result = tokio::task::spawn(serve_chunked_response( listener, @@ -605,15 +656,15 @@ mod test { client .send(ClientRequest::Disconnect { cause: None }) .await?; - tokio::time::timeout(Duration::from_millis(100), rx).await??; - 
tokio::time::timeout(Duration::from_millis(100), server_result).await???; + tokio::time::timeout(Duration::from_secs(5), rx).await??; + tokio::time::timeout(Duration::from_secs(5), server_result).await???; Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_send() -> Result<(), Box> { - let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - let server = Server::new(port, true).await; + let (listener, port) = bind_free_port().await; + let server = Server::new(listener, true).await; let (tx, rx) = tokio::sync::oneshot::channel::<()>(); let server_result = tokio::task::spawn(server.listen(tx)); let (ws_conn, _) = @@ -623,15 +674,15 @@ mod test { client .send(ClientRequest::Disconnect { cause: None }) .await?; - tokio::time::timeout(Duration::from_millis(10), rx).await??; - tokio::time::timeout(Duration::from_millis(10), server_result).await???; + tokio::time::timeout(Duration::from_secs(5), rx).await??; + tokio::time::timeout(Duration::from_secs(5), server_result).await???; Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_recv() -> Result<(), Box> { - let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - let server = Server::new(port, false).await; + let (listener, port) = bind_free_port().await; + let server = Server::new(listener, false).await; let (tx, rx) = tokio::sync::oneshot::channel::<()>(); let server_result = tokio::task::spawn(server.listen(tx)); let (ws_conn, _) = @@ -642,8 +693,8 @@ mod test { client .send(ClientRequest::Disconnect { cause: None }) .await?; - tokio::time::timeout(Duration::from_millis(10), rx).await??; - tokio::time::timeout(Duration::from_millis(10), server_result).await???; + tokio::time::timeout(Duration::from_secs(5), rx).await??; + tokio::time::timeout(Duration::from_secs(5), server_result).await???; Ok(()) } } diff --git a/rust/src/client_api/streaming.rs b/rust/src/client_api/streaming.rs index 554d021..5602fab 100644 --- 
a/rust/src/client_api/streaming.rs +++ b/rust/src/client_api/streaming.rs @@ -24,9 +24,6 @@ pub const MAX_TOTAL_CHUNKS: u32 = 256; /// Maximum concurrent streams in a single `ReassemblyBuffer`. pub const MAX_CONCURRENT_STREAMS: usize = 8; -/// Chunks to send before yielding to the tokio runtime. -pub const MAX_CHUNKS_PER_BATCH: usize = 32; - /// Zero-copy chunking: split `data` into (index, total, slice) tuples using `Bytes::slice()`. fn chunk_bytes(data: &Bytes) -> Vec<(u32, u32, Bytes)> { let total = data.len().div_ceil(CHUNK_SIZE).max(1) as u32; @@ -102,12 +99,20 @@ pub enum StreamError { ChannelClosed, #[error("stream truncated: received {received} of {expected} bytes")] Truncated { received: u64, expected: u64 }, + #[error("stream overflow: received {received} bytes, expected at most {expected} bytes")] + Overflow { received: u64, expected: u64 }, } +/// Timeout for incomplete streams in the reassembly buffer. +#[cfg(not(target_family = "wasm"))] +const STREAM_TTL: std::time::Duration = std::time::Duration::from_secs(60); + struct StreamState { chunks: Vec>, total: u32, received: u32, + #[cfg(not(target_family = "wasm"))] + created_at: std::time::Instant, } /// Reassembly buffer keyed by stream ID. Supports concurrent streams. @@ -148,6 +153,10 @@ impl ReassemblyBuffer { }); } + // Evict stale entries before checking the concurrent limit. + #[cfg(not(target_family = "wasm"))] + self.evict_stale(); + // Reject new streams when the concurrent stream limit is reached. if !self.streams.contains_key(&stream_id) && self.streams.len() >= MAX_CONCURRENT_STREAMS { return Err(StreamError::TooManyConcurrentStreams { @@ -163,6 +172,8 @@ impl ReassemblyBuffer { chunks: vec![None; total as usize], total, received: 0, + #[cfg(not(target_family = "wasm"))] + created_at: std::time::Instant::now(), }); if state.total != total { @@ -193,6 +204,18 @@ impl ReassemblyBuffer { Ok(None) } } + + /// Remove a stream by ID, returning `true` if it existed. 
+ pub fn remove_stream(&mut self, stream_id: u32) -> bool { + self.streams.remove(&stream_id).is_some() + } + + #[cfg(not(target_family = "wasm"))] + fn evict_stale(&mut self) { + let now = std::time::Instant::now(); + self.streams + .retain(|_id, state| now.duration_since(state.created_at) < STREAM_TTL); + } } impl Default for ReassemblyBuffer { @@ -250,12 +273,21 @@ mod app_stream { /// Wait for all chunks and return the fully reassembled payload. /// /// Returns [`StreamError::Truncated`] if the sender closes before all - /// expected bytes have been delivered. + /// expected bytes have been delivered, or [`StreamError::Overflow`] if + /// more data is received than the header promised. pub async fn assemble(mut self) -> Result, StreamError> { // Cap pre-allocation to avoid OOM from a malicious total_bytes header. const MAX_PREALLOC: usize = 50 * 1024 * 1024; + // Allow up to one extra chunk of slack beyond total_bytes. + let max_bytes = (self.total_bytes as usize).saturating_add(super::CHUNK_SIZE); let mut buf = Vec::with_capacity((self.total_bytes as usize).min(MAX_PREALLOC)); while let Some(chunk) = self.chunk_rx.recv().await { + if buf.len().saturating_add(chunk.len()) > max_bytes { + return Err(StreamError::Overflow { + received: buf.len() as u64 + chunk.len() as u64, + expected: self.total_bytes, + }); + } buf.extend_from_slice(&chunk); } if (buf.len() as u64) < self.total_bytes { @@ -522,6 +554,66 @@ mod tests { assert!(matches!(err, StreamError::TooManyConcurrentStreams { .. 
})); } + #[test] + fn reassembly_out_of_order() { + let data: Vec = (0..CHUNK_SIZE * 3).map(|i| (i % 256) as u8).collect(); + let chunks = chunk_request(data.clone(), 10); + assert_eq!(chunks.len(), 3); + + let mut buf = ReassemblyBuffer::new(); + // Feed in reverse order: 2, 0, 1 + let indices = [2, 0, 1]; + let mut result = None; + for &i in &indices { + if let ClientRequest::StreamChunk { + stream_id, + index, + total, + data: chunk_data, + } = &chunks[i] + { + if let Some(r) = buf + .receive_chunk(*stream_id, *index, *total, chunk_data.clone()) + .unwrap() + { + result = Some(r); + } + } + } + assert_eq!(result.unwrap(), data); + } + + #[test] + fn chunk_exact_boundary() { + // Exactly one chunk + let data = vec![0xEE; CHUNK_SIZE]; + let chunks = chunk_request(data, 5); + assert_eq!(chunks.len(), 1); + + // Exactly two chunks + let data2 = vec![0xEE; CHUNK_SIZE * 2]; + let chunks2 = chunk_request(data2, 6); + assert_eq!(chunks2.len(), 2); + + // One byte over two chunks + let data3 = vec![0xEE; CHUNK_SIZE * 2 + 1]; + let chunks3 = chunk_request(data3, 7); + assert_eq!(chunks3.len(), 3); + } + + #[test] + fn remove_stream_cleans_up() { + let mut buf = ReassemblyBuffer::new(); + buf.receive_chunk(1, 0, 3, Bytes::from_static(&[1, 2, 3])) + .unwrap(); + assert!(buf.remove_stream(1)); + assert!(!buf.remove_stream(1)); // already removed + + // Can start a new stream with the same id + buf.receive_chunk(1, 0, 2, Bytes::from_static(&[4, 5])) + .unwrap(); + } + #[cfg(all(feature = "net", not(target_family = "wasm")))] mod stream_handle_tests { use super::super::*; @@ -575,5 +667,38 @@ mod tests { assert_eq!(collected.len(), CHUNK_SIZE * 2); assert!(collected.iter().all(|&b| b == 0xCD)); } + + #[tokio::test] + async fn ws_stream_assemble_truncated() { + let content = StreamContent::Raw; + let (handle, sender) = ws_stream_pair(content, 1000); + // Send less than promised, then drop sender + sender.send_chunk(Bytes::from(vec![0; 100])).unwrap(); + drop(sender); + let err 
= handle.assemble().await.unwrap_err(); + assert!(matches!( + err, + StreamError::Truncated { + received: 100, + expected: 1000 + } + )); + } + + #[tokio::test] + async fn ws_stream_assemble_overflow() { + let content = StreamContent::Raw; + // Claim only 10 bytes + let (handle, sender) = ws_stream_pair(content, 10); + // Send way more than promised (over total_bytes + CHUNK_SIZE) + let overflow_size = 10 + CHUNK_SIZE + 1; + tokio::spawn(async move { + sender + .send_chunk(Bytes::from(vec![0xFF; overflow_size])) + .unwrap(); + }); + let err = handle.assemble().await.unwrap_err(); + assert!(matches!(err, StreamError::Overflow { .. })); + } } } diff --git a/rust/src/generated/host_response_generated.rs b/rust/src/generated/host_response_generated.rs index 91ce18d..c7e535f 100644 --- a/rust/src/generated/host_response_generated.rs +++ b/rust/src/generated/host_response_generated.rs @@ -2714,15 +2714,15 @@ pub mod host_response { ds.finish() } } - pub enum OkOffset {} + pub enum StreamChunkOffset {} #[derive(Copy, Clone, PartialEq)] - pub struct Ok<'a> { + pub struct StreamChunk<'a> { pub _tab: flatbuffers::Table<'a>, } - impl<'a> flatbuffers::Follow<'a> for Ok<'a> { - type Inner = Ok<'a>; + impl<'a> flatbuffers::Follow<'a> for StreamChunk<'a> { + type Inner = StreamChunk<'a>; #[inline] unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { Self { @@ -2731,12 +2731,15 @@ pub mod host_response { } } - impl<'a> Ok<'a> { - pub const VT_MSG: flatbuffers::VOffsetT = 4; + impl<'a> StreamChunk<'a> { + pub const VT_STREAM_ID: flatbuffers::VOffsetT = 4; + pub const VT_INDEX: flatbuffers::VOffsetT = 6; + pub const VT_TOTAL: flatbuffers::VOffsetT = 8; + pub const VT_DATA: flatbuffers::VOffsetT = 10; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { - Ok { _tab: table } + StreamChunk { _tab: table } } #[allow(unused_mut)] pub fn create< @@ -2746,29 +2749,68 @@ pub mod host_response { A: flatbuffers::Allocator + 'bldr, >( _fbb: &'mut_bldr mut 
flatbuffers::FlatBufferBuilder<'bldr, A>, - args: &'args OkArgs<'args>, - ) -> flatbuffers::WIPOffset> { - let mut builder = OkBuilder::new(_fbb); - if let Some(x) = args.msg { - builder.add_msg(x); + args: &'args StreamChunkArgs<'args>, + ) -> flatbuffers::WIPOffset> { + let mut builder = StreamChunkBuilder::new(_fbb); + if let Some(x) = args.data { + builder.add_data(x); } + builder.add_total(args.total); + builder.add_index(args.index); + builder.add_stream_id(args.stream_id); builder.finish() } #[inline] - pub fn msg(&self) -> &'a str { + pub fn stream_id(&self) -> u32 { // Safety: // Created from valid Table for this object // which contains a valid value in this slot unsafe { self._tab - .get::>(Ok::VT_MSG, None) + .get::(StreamChunk::VT_STREAM_ID, Some(0)) + .unwrap() + } + } + #[inline] + pub fn index(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_INDEX, Some(0)) + .unwrap() + } + } + #[inline] + pub fn total(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(StreamChunk::VT_TOTAL, Some(0)) + .unwrap() + } + } + #[inline] + pub fn data(&self) -> flatbuffers::Vector<'a, u8> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + StreamChunk::VT_DATA, + None, + ) .unwrap() } } } - impl flatbuffers::Verifiable for Ok<'_> { + impl flatbuffers::Verifiable for StreamChunk<'_> { #[inline] fn run_verifier( v: &mut flatbuffers::Verifier, @@ -2776,65 +2818,96 @@ pub mod host_response { ) -> Result<(), flatbuffers::InvalidFlatbuffer> { use self::flatbuffers::Verifiable; v.visit_table(pos)? - .visit_field::>("msg", Self::VT_MSG, true)? + .visit_field::("stream_id", Self::VT_STREAM_ID, false)? + .visit_field::("index", Self::VT_INDEX, false)? 
+ .visit_field::("total", Self::VT_TOTAL, false)? + .visit_field::>>( + "data", + Self::VT_DATA, + true, + )? .finish(); Ok(()) } } - pub struct OkArgs<'a> { - pub msg: Option>, + pub struct StreamChunkArgs<'a> { + pub stream_id: u32, + pub index: u32, + pub total: u32, + pub data: Option>>, } - impl<'a> Default for OkArgs<'a> { + impl<'a> Default for StreamChunkArgs<'a> { #[inline] fn default() -> Self { - OkArgs { - msg: None, // required field + StreamChunkArgs { + stream_id: 0, + index: 0, + total: 0, + data: None, // required field } } } - pub struct OkBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { + pub struct StreamChunkBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, start_: flatbuffers::WIPOffset, } - impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> OkBuilder<'a, 'b, A> { + impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> StreamChunkBuilder<'a, 'b, A> { #[inline] - pub fn add_msg(&mut self, msg: flatbuffers::WIPOffset<&'b str>) { + pub fn add_stream_id(&mut self, stream_id: u32) { self.fbb_ - .push_slot_always::>(Ok::VT_MSG, msg); + .push_slot::(StreamChunk::VT_STREAM_ID, stream_id, 0); } #[inline] - pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> OkBuilder<'a, 'b, A> { + pub fn add_index(&mut self, index: u32) { + self.fbb_.push_slot::(StreamChunk::VT_INDEX, index, 0); + } + #[inline] + pub fn add_total(&mut self, total: u32) { + self.fbb_.push_slot::(StreamChunk::VT_TOTAL, total, 0); + } + #[inline] + pub fn add_data(&mut self, data: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(StreamChunk::VT_DATA, data); + } + #[inline] + pub fn new( + _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + ) -> StreamChunkBuilder<'a, 'b, A> { let start = _fbb.start_table(); - OkBuilder { + StreamChunkBuilder { fbb_: _fbb, start_: start, } } #[inline] - pub fn finish(self) -> flatbuffers::WIPOffset> { + pub fn finish(self) -> flatbuffers::WIPOffset> { let o = 
self.fbb_.end_table(self.start_); - self.fbb_.required(o, Ok::VT_MSG, "msg"); + self.fbb_.required(o, StreamChunk::VT_DATA, "data"); flatbuffers::WIPOffset::new(o.value()) } } - impl core::fmt::Debug for Ok<'_> { + impl core::fmt::Debug for StreamChunk<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mut ds = f.debug_struct("Ok"); - ds.field("msg", &self.msg()); + let mut ds = f.debug_struct("StreamChunk"); + ds.field("stream_id", &self.stream_id()); + ds.field("index", &self.index()); + ds.field("total", &self.total()); + ds.field("data", &self.data()); ds.finish() } } - pub enum ErrorOffset {} + pub enum OkOffset {} #[derive(Copy, Clone, PartialEq)] - pub struct Error<'a> { + pub struct Ok<'a> { pub _tab: flatbuffers::Table<'a>, } - impl<'a> flatbuffers::Follow<'a> for Error<'a> { - type Inner = Error<'a>; + impl<'a> flatbuffers::Follow<'a> for Ok<'a> { + type Inner = Ok<'a>; #[inline] unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { Self { @@ -2843,12 +2916,12 @@ pub mod host_response { } } - impl<'a> Error<'a> { + impl<'a> Ok<'a> { pub const VT_MSG: flatbuffers::VOffsetT = 4; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { - Error { _tab: table } + Ok { _tab: table } } #[allow(unused_mut)] pub fn create< @@ -2858,9 +2931,9 @@ pub mod host_response { A: flatbuffers::Allocator + 'bldr, >( _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr, A>, - args: &'args ErrorArgs<'args>, - ) -> flatbuffers::WIPOffset> { - let mut builder = ErrorBuilder::new(_fbb); + args: &'args OkArgs<'args>, + ) -> flatbuffers::WIPOffset> { + let mut builder = OkBuilder::new(_fbb); if let Some(x) = args.msg { builder.add_msg(x); } @@ -2874,13 +2947,13 @@ pub mod host_response { // which contains a valid value in this slot unsafe { self._tab - .get::>(Error::VT_MSG, None) + .get::>(Ok::VT_MSG, None) .unwrap() } } } - impl flatbuffers::Verifiable for Error<'_> { + impl flatbuffers::Verifiable for Ok<'_> { 
#[inline] fn run_verifier( v: &mut flatbuffers::Verifier, @@ -2893,60 +2966,60 @@ pub mod host_response { Ok(()) } } - pub struct ErrorArgs<'a> { + pub struct OkArgs<'a> { pub msg: Option>, } - impl<'a> Default for ErrorArgs<'a> { + impl<'a> Default for OkArgs<'a> { #[inline] fn default() -> Self { - ErrorArgs { + OkArgs { msg: None, // required field } } } - pub struct ErrorBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { + pub struct OkBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, start_: flatbuffers::WIPOffset, } - impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ErrorBuilder<'a, 'b, A> { + impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> OkBuilder<'a, 'b, A> { #[inline] pub fn add_msg(&mut self, msg: flatbuffers::WIPOffset<&'b str>) { self.fbb_ - .push_slot_always::>(Error::VT_MSG, msg); + .push_slot_always::>(Ok::VT_MSG, msg); } #[inline] - pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> ErrorBuilder<'a, 'b, A> { + pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> OkBuilder<'a, 'b, A> { let start = _fbb.start_table(); - ErrorBuilder { + OkBuilder { fbb_: _fbb, start_: start, } } #[inline] - pub fn finish(self) -> flatbuffers::WIPOffset> { + pub fn finish(self) -> flatbuffers::WIPOffset> { let o = self.fbb_.end_table(self.start_); - self.fbb_.required(o, Error::VT_MSG, "msg"); + self.fbb_.required(o, Ok::VT_MSG, "msg"); flatbuffers::WIPOffset::new(o.value()) } } - impl core::fmt::Debug for Error<'_> { + impl core::fmt::Debug for Ok<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mut ds = f.debug_struct("Error"); + let mut ds = f.debug_struct("Ok"); ds.field("msg", &self.msg()); ds.finish() } } - pub enum StreamChunkOffset {} + pub enum ErrorOffset {} #[derive(Copy, Clone, PartialEq)] - pub struct StreamChunk<'a> { + pub struct Error<'a> { pub _tab: flatbuffers::Table<'a>, } - impl<'a> flatbuffers::Follow<'a> for 
StreamChunk<'a> { - type Inner = StreamChunk<'a>; + impl<'a> flatbuffers::Follow<'a> for Error<'a> { + type Inner = Error<'a>; #[inline] unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { Self { @@ -2955,15 +3028,12 @@ pub mod host_response { } } - impl<'a> StreamChunk<'a> { - pub const VT_STREAM_ID: flatbuffers::VOffsetT = 4; - pub const VT_INDEX: flatbuffers::VOffsetT = 6; - pub const VT_TOTAL: flatbuffers::VOffsetT = 8; - pub const VT_DATA: flatbuffers::VOffsetT = 10; + impl<'a> Error<'a> { + pub const VT_MSG: flatbuffers::VOffsetT = 4; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { - StreamChunk { _tab: table } + Error { _tab: table } } #[allow(unused_mut)] pub fn create< @@ -2973,68 +3043,29 @@ pub mod host_response { A: flatbuffers::Allocator + 'bldr, >( _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr, A>, - args: &'args StreamChunkArgs<'args>, - ) -> flatbuffers::WIPOffset> { - let mut builder = StreamChunkBuilder::new(_fbb); - if let Some(x) = args.data { - builder.add_data(x); + args: &'args ErrorArgs<'args>, + ) -> flatbuffers::WIPOffset> { + let mut builder = ErrorBuilder::new(_fbb); + if let Some(x) = args.msg { + builder.add_msg(x); } - builder.add_total(args.total); - builder.add_index(args.index); - builder.add_stream_id(args.stream_id); builder.finish() } #[inline] - pub fn stream_id(&self) -> u32 { - // Safety: - // Created from valid Table for this object - // which contains a valid value in this slot - unsafe { - self._tab - .get::(StreamChunk::VT_STREAM_ID, Some(0)) - .unwrap() - } - } - #[inline] - pub fn index(&self) -> u32 { - // Safety: - // Created from valid Table for this object - // which contains a valid value in this slot - unsafe { - self._tab - .get::(StreamChunk::VT_INDEX, Some(0)) - .unwrap() - } - } - #[inline] - pub fn total(&self) -> u32 { - // Safety: - // Created from valid Table for this object - // which contains a valid value in this slot - unsafe { - self._tab - 
.get::(StreamChunk::VT_TOTAL, Some(0)) - .unwrap() - } - } - #[inline] - pub fn data(&self) -> flatbuffers::Vector<'a, u8> { + pub fn msg(&self) -> &'a str { // Safety: // Created from valid Table for this object // which contains a valid value in this slot unsafe { self._tab - .get::>>( - StreamChunk::VT_DATA, - None, - ) + .get::>(Error::VT_MSG, None) .unwrap() } } } - impl flatbuffers::Verifiable for StreamChunk<'_> { + impl flatbuffers::Verifiable for Error<'_> { #[inline] fn run_verifier( v: &mut flatbuffers::Verifier, @@ -3042,84 +3073,53 @@ pub mod host_response { ) -> Result<(), flatbuffers::InvalidFlatbuffer> { use self::flatbuffers::Verifiable; v.visit_table(pos)? - .visit_field::("stream_id", Self::VT_STREAM_ID, false)? - .visit_field::("index", Self::VT_INDEX, false)? - .visit_field::("total", Self::VT_TOTAL, false)? - .visit_field::>>( - "data", - Self::VT_DATA, - true, - )? + .visit_field::>("msg", Self::VT_MSG, true)? .finish(); Ok(()) } } - pub struct StreamChunkArgs<'a> { - pub stream_id: u32, - pub index: u32, - pub total: u32, - pub data: Option>>, + pub struct ErrorArgs<'a> { + pub msg: Option>, } - impl<'a> Default for StreamChunkArgs<'a> { + impl<'a> Default for ErrorArgs<'a> { #[inline] fn default() -> Self { - StreamChunkArgs { - stream_id: 0, - index: 0, - total: 0, - data: None, // required field + ErrorArgs { + msg: None, // required field } } } - pub struct StreamChunkBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { + pub struct ErrorBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, start_: flatbuffers::WIPOffset, } - impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> StreamChunkBuilder<'a, 'b, A> { - #[inline] - pub fn add_stream_id(&mut self, stream_id: u32) { - self.fbb_ - .push_slot::(StreamChunk::VT_STREAM_ID, stream_id, 0); - } - #[inline] - pub fn add_index(&mut self, index: u32) { - self.fbb_.push_slot::(StreamChunk::VT_INDEX, index, 0); - } - #[inline] - pub fn 
add_total(&mut self, total: u32) { - self.fbb_.push_slot::(StreamChunk::VT_TOTAL, total, 0); - } + impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ErrorBuilder<'a, 'b, A> { #[inline] - pub fn add_data(&mut self, data: flatbuffers::WIPOffset>) { + pub fn add_msg(&mut self, msg: flatbuffers::WIPOffset<&'b str>) { self.fbb_ - .push_slot_always::>(StreamChunk::VT_DATA, data); + .push_slot_always::>(Error::VT_MSG, msg); } #[inline] - pub fn new( - _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, - ) -> StreamChunkBuilder<'a, 'b, A> { + pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> ErrorBuilder<'a, 'b, A> { let start = _fbb.start_table(); - StreamChunkBuilder { + ErrorBuilder { fbb_: _fbb, start_: start, } } #[inline] - pub fn finish(self) -> flatbuffers::WIPOffset> { + pub fn finish(self) -> flatbuffers::WIPOffset> { let o = self.fbb_.end_table(self.start_); - self.fbb_.required(o, StreamChunk::VT_DATA, "data"); + self.fbb_.required(o, Error::VT_MSG, "msg"); flatbuffers::WIPOffset::new(o.value()) } } - impl core::fmt::Debug for StreamChunk<'_> { + impl core::fmt::Debug for Error<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mut ds = f.debug_struct("StreamChunk"); - ds.field("stream_id", &self.stream_id()); - ds.field("index", &self.index()); - ds.field("total", &self.total()); - ds.field("data", &self.data()); + let mut ds = f.debug_struct("Error"); + ds.field("msg", &self.msg()); ds.finish() } } diff --git a/schemas/flatbuffers/host_response.fbs b/schemas/flatbuffers/host_response.fbs index 8622eb6..3a336ce 100644 --- a/schemas/flatbuffers/host_response.fbs +++ b/schemas/flatbuffers/host_response.fbs @@ -104,9 +104,9 @@ union HostResponseType { ContractResponse, DelegateResponse, GenerateRandData, - StreamChunk Ok, Error, + StreamChunk, } table HostResponse { From 8df4e097e0a471e6f223e50774832e0fd84aa7a7 Mon Sep 17 00:00:00 2001 From: Hector Santos Date: Thu, 5 Mar 2026 19:10:44 +0100 Subject: [PATCH 5/5] 
improve and fix streaming bugs --- rust/src/client_api/browser.rs | 1 + rust/src/client_api/regular.rs | 57 ++++++++++++++------------------ rust/src/client_api/streaming.rs | 8 +++++ 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/rust/src/client_api/browser.rs b/rust/src/client_api/browser.rs index 6579e6a..4e30d48 100644 --- a/rust/src/client_api/browser.rs +++ b/rust/src/client_api/browser.rs @@ -104,6 +104,7 @@ impl WebApi { } Ok(None) => return, // more chunks needed Err(e) => { + reassembly_clone.borrow_mut().remove_stream(stream_id); eh_clone.borrow_mut()(Error::ConnectionError(serde_json::json!({ "error": format!("{e}"), "source": "streaming reassembly" diff --git a/rust/src/client_api/regular.rs b/rust/src/client_api/regular.rs index b848f2c..35ac71e 100644 --- a/rust/src/client_api/regular.rs +++ b/rust/src/client_api/regular.rs @@ -1,11 +1,13 @@ -use std::{borrow::Cow, collections::HashMap, collections::VecDeque, future::Future, task::Poll}; +use std::{ + borrow::Cow, collections::HashMap, collections::VecDeque, future::Future, pin::Pin, task::Poll, +}; use super::{ client_events::{ClientError, ClientRequest, ErrorKind, HostResponse}, streaming::WsStreamHandle, Error, HostResult, }; -use futures::{pin_mut, FutureExt, Sink, SinkExt, Stream, StreamExt}; +use futures::{stream::FuturesUnordered, Sink, SinkExt, Stream, StreamExt}; use tokio::{ net::TcpStream, sync::mpsc::{self, Receiver, Sender}, @@ -24,8 +26,8 @@ pub struct WebApi { request_tx: Sender>, response_rx: Receiver, stream_rx: Receiver, - queue: Vec>, - pending_streams: VecDeque + Send>>>, + queue: VecDeque>, + pending_streams: FuturesUnordered + Send>>>, } impl Drop for WebApi { @@ -44,12 +46,10 @@ impl Stream for WebApi { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - // First, try to complete any pending stream assemblies. 
- if let Some(fut) = self.pending_streams.front_mut() { - if let Poll::Ready(result) = fut.as_mut().poll(cx) { - self.pending_streams.pop_front(); - return Poll::Ready(Some(result)); - } + // Poll all pending stream assemblies concurrently. + match self.pending_streams.poll_next_unpin(cx) { + Poll::Ready(Some(result)) => return Poll::Ready(Some(result)), + Poll::Ready(None) | Poll::Pending => {} } // Poll regular responses. @@ -71,7 +71,7 @@ impl Stream for WebApi { .map_err(|e| ClientError::from(format!("{e}")))?; inner }); - self.pending_streams.push_back(fut); + self.pending_streams.push(fut); cx.waker().wake_by_ref(); Poll::Pending } @@ -99,7 +99,7 @@ impl Sink> for WebApi { mut self: std::pin::Pin<&mut Self>, item: ClientRequest<'static>, ) -> Result<(), Self::Error> { - self.queue.push(item); + self.queue.push_back(item); Ok(()) } @@ -107,27 +107,20 @@ impl Sink> for WebApi { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - let mut queue = vec![]; - std::mem::swap(&mut queue, &mut self.queue); - let mut error = false; - while let Some(item) = queue.pop() { - let f = self.request_tx.send(item); - pin_mut!(f); - match f.poll_unpin(cx) { - Poll::Ready(Ok(_)) => continue, - Poll::Ready(Err(_err)) => { - error = true; - break; + while let Some(item) = self.queue.pop_front() { + match self.request_tx.try_send(item) { + Ok(()) => continue, + Err(mpsc::error::TrySendError::Full(item)) => { + self.queue.push_front(item); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + Err(mpsc::error::TrySendError::Closed(_)) => { + return Poll::Ready(Err(ErrorKind::ChannelClosed.into())); } - Poll::Pending => {} } } - if error { - self.queue.append(&mut queue); - Poll::Ready(Err(ErrorKind::ChannelClosed.into())) - } else { - Poll::Ready(Ok(())) - } + Poll::Ready(Ok(())) } fn poll_close( @@ -153,8 +146,8 @@ impl WebApi { request_tx, response_rx, stream_rx, - queue: vec![], - pending_streams: VecDeque::new(), + queue: VecDeque::new(), + 
pending_streams: FuturesUnordered::new(), } } diff --git a/rust/src/client_api/streaming.rs b/rust/src/client_api/streaming.rs index 5602fab..8db1df5 100644 --- a/rust/src/client_api/streaming.rs +++ b/rust/src/client_api/streaming.rs @@ -276,6 +276,14 @@ mod app_stream { /// expected bytes have been delivered, or [`StreamError::Overflow`] if /// more data is received than the header promised. pub async fn assemble(mut self) -> Result, StreamError> { + // Reject total_bytes exceeding the protocol maximum before allocating. + let protocol_max = super::MAX_TOTAL_CHUNKS as u64 * super::CHUNK_SIZE as u64; + if self.total_bytes > protocol_max { + return Err(StreamError::Overflow { + received: 0, + expected: protocol_max, + }); + } // Cap pre-allocation to avoid OOM from a malicious total_bytes header. const MAX_PREALLOC: usize = 50 * 1024 * 1024; // Allow up to one extra chunk of slack beyond total_bytes.