diff --git a/README.md b/README.md index fa5faf3..75cc51b 100644 --- a/README.md +++ b/README.md @@ -870,6 +870,37 @@ The `IAudioInference` class supports the following parameters: - `duration`: Duration of the generated audio in seconds - `includeCost`: Whether to include cost information in the response +### Text inference streaming + +To stream text inference (e.g. LLM chat) over HTTP SSE, set `deliveryMethod="stream"`. The SDK yields content chunks (strings) and a final `IText` with usage and cost: + +```python +import asyncio +from runware import Runware, ITextInference, ITextInferenceMessage + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + request = ITextInference( + model="runware:qwen3-thinking@1", + messages=[ITextInferenceMessage(role="user", content="Explain photosynthesis in one sentence.")], + deliveryMethod="stream", + includeCost=True, + ) + + stream = await runware.textInference(request) + async for chunk in stream: + if isinstance(chunk, str): + print(chunk, end="", flush=True) + else: + print(chunk) + +asyncio.run(main()) +``` + +Streaming uses the same concurrency limit as other requests (`RUNWARE_MAX_CONCURRENT_REQUESTS`). To allow longer streams, set `RUNWARE_TEXT_STREAM_TIMEOUT` (milliseconds; default 600000). + ### Model Upload To upload model using the Runware API, you can use the `uploadModel` method of the `Runware` class. 
Here are examples: @@ -1106,6 +1137,9 @@ RUNWARE_AUDIO_INFERENCE_TIMEOUT=300000 # Audio generation (default: 5 min) RUNWARE_AUDIO_POLLING_DELAY=1000 # Delay between status checks (default: 1 sec) RUNWARE_MAX_POLLS_AUDIO_GENERATION=240 # Max polling attempts for audio inference (default: 240, ~4 min total) +# Text Operations (milliseconds) +RUNWARE_TEXT_STREAM_TIMEOUT=600000 # Text inference streaming (SSE) read timeout (default: 10 min) + # Other Operations (milliseconds) RUNWARE_PROMPT_ENHANCE_TIMEOUT=60000 # Prompt enhancement (default: 1 min) RUNWARE_WEBHOOK_TIMEOUT=30000 # Webhook acknowledgment (default: 30 sec) diff --git a/requirements.txt b/requirements.txt index 611060a..f202503 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ aiofiles==23.2.1 +httpx>=0.27.0 python-dotenv==1.0.1 websockets>=12.0 \ No newline at end of file diff --git a/runware/base.py b/runware/base.py index ef819d6..563f8cd 100644 --- a/runware/base.py +++ b/runware/base.py @@ -1,13 +1,15 @@ import asyncio import inspect +import json import logging import os import re from asyncio import gather from dataclasses import asdict from random import uniform -from typing import List, Optional, Union, Callable, Any, Dict, Tuple +from typing import List, Optional, Union, Callable, Any, Dict, Tuple, AsyncIterator +import httpx from websockets.protocol import State from .logging_config import configure_logging @@ -58,11 +60,13 @@ IUploadMediaRequest, ITextInference, IText, + ITextInferenceUsage, ) from .types import IImage, IError, SdkType, ListenerType from .utils import ( BASE_RUNWARE_URLS, getUUID, + get_http_url_from_ws_url, fileToBase64, createImageFromResponse, createImageToTextFromResponse, @@ -81,6 +85,7 @@ createAsyncTaskResponse, VIDEO_INITIAL_TIMEOUT, TEXT_INITIAL_TIMEOUT, + TEXT_STREAM_READ_TIMEOUT, VIDEO_POLLING_DELAY, WEBHOOK_TIMEOUT, IMAGE_INFERENCE_TIMEOUT, @@ -2028,7 +2033,20 @@ async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsync 
async def _requestTextStream(
    self, requestText: ITextInference
) -> AsyncIterator[Union[str, IText]]:
    """Stream a text inference over HTTP Server-Sent Events (SSE).

    POSTs the request built by ``_buildTextRequest`` (wrapped in a
    one-element list) to the HTTP URL derived from ``self._url`` and consumes
    the response line by line.

    Args:
        requestText: The text inference request. Its ``taskUUID`` is filled
            in with a fresh UUID when it is not already set.

    Yields:
        Each non-empty ``delta.content`` string as it arrives, then a single
        ``IText`` (finish reason, usage, cost) once a chunk carries a
        ``finish_reason``. The stream ends after that item.

    Raises:
        RunwareAPIError: Carrying the server's ``error`` payload unchanged
            when a chunk reports one, or ``{"message": str(exc)}`` for any
            other failure (HTTP status, transport, timeout, unexpected chunk
            shape). The original exception is attached as ``__cause__``.
    """
    requestText.taskUUID = requestText.taskUUID or getUUID()
    request_object = self._buildTextRequest(requestText)
    http_url = get_http_url_from_ws_url(self._url or "")
    headers = {
        "Accept": "text/event-stream",
        "Authorization": f"Bearer {self._apiKey}",
        "Content-Type": "application/json",
    }
    sse_prefix = "data:"
    try:
        # TEXT_STREAM_READ_TIMEOUT is expressed in milliseconds; httpx takes seconds.
        async with httpx.AsyncClient(timeout=TEXT_STREAM_READ_TIMEOUT / 1000) as client:
            async with client.stream(
                "POST",
                http_url,
                json=[request_object],
                headers=headers,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # Strip the SSE field name only when it really is the line
                    # prefix. Replacing the first "data:" anywhere would corrupt
                    # an un-prefixed JSON line whose text contains "data:".
                    if line.startswith(sse_prefix):
                        line = line[len(sse_prefix):]
                    try:
                        event = json.loads(line)
                    except json.JSONDecodeError:
                        # Blank separators and any other non-JSON lines are skipped.
                        continue
                    # The chunk may be wrapped in a "data" envelope or sent bare.
                    data = event.get("data") or event
                    if data.get("error") is not None:
                        raise RunwareAPIError(data["error"])
                    choice = (data.get("choices") or [{}])[0]
                    delta = choice.get("delta") or {}
                    content = delta.get("content")
                    if content:
                        yield content
                    if choice.get("finish_reason") is not None:
                        # NOTE(review): "usage" may be absent on the final chunk;
                        # confirm instantiateDataclass tolerates None.
                        usage = instantiateDataclass(ITextInferenceUsage, data.get("usage"))
                        yield IText(
                            taskType=ETaskType.TEXT_INFERENCE.value,
                            taskUUID=data.get("taskUUID") or "",
                            finishReason=choice.get("finish_reason"),
                            usage=usage,
                            cost=data.get("cost"),
                        )
                        return
    except RunwareAPIError:
        # Already carries the structured server error; do not flatten it to a string.
        raise
    except Exception as e:
        raise RunwareAPIError({"message": str(e)}) from e
_WS_TO_HTTP = {
    BASE_RUNWARE_URLS[env]: BASE_RUNWARE_HTTP_URLS[env]
    for env in (Environment.PRODUCTION, Environment.TEST)
}


def get_http_url_from_ws_url(ws_url: str) -> str:
    """Translate a WebSocket base URL into its HTTP base URL.

    ``ws_url`` is looked up in ``_WS_TO_HTTP``. An empty value, or a URL that
    is not a key of that mapping, resolves to the production HTTP URL.
    """
    production_http_url = BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION]
    if ws_url:
        return _WS_TO_HTTP.get(ws_url, production_http_url)
    return production_http_url


# Read timeout for text streaming, in milliseconds.
# Kept long so a slow SSE stream is not cut off by a ReadTimeout mid-stream.
# Consumed by _requestTextStream() when deliveryMethod=stream.
# Override with the RUNWARE_TEXT_STREAM_TIMEOUT environment variable.
TEXT_STREAM_READ_TIMEOUT = int(os.getenv("RUNWARE_TEXT_STREAM_TIMEOUT", 600000))