diff --git a/README.md b/README.md index fa5faf3..5008ca0 100644 --- a/README.md +++ b/README.md @@ -201,44 +201,6 @@ Your webhook endpoint will receive a POST request with the same format as synchr } ``` -### Video Inference with Skip Response - -For long-running video generation tasks, you can use `skipResponse` to submit the task and retrieve results later. This is useful for handling system interruptions, batch processing, or building queue-based systems. -```python -from runware import Runware, IVideoInference - -async def main() -> None: - runware = Runware(api_key=RUNWARE_API_KEY) - await runware.connect() - - # Submit video task without waiting - request = IVideoInference( - model="openai:3@2", - positivePrompt="A beautiful sunset over the ocean", - duration=4, - width=1280, - height=720, - skipResponse=True, - ) - - response = await runware.videoInference(requestVideo=request) - task_uuid = response.taskUUID - print(f"Task submitted: {task_uuid}") - - # Later, retrieve results - videos = await runware.getResponse( - taskUUID=task_uuid, - numberResults=1 - ) - - for video in videos: - print(f"Video URL: {video.videoURL}") -``` - -**Parameters:** -- `skipResponse`: Set to `True` to return immediately with `taskUUID` instead of waiting for completion -- Use `getResponse(taskUUID)` to retrieve results at any time - ### Video Inference with Async Delivery Method For long-running video generation tasks, you can use `deliveryMethod="async"` to submit the task and retrieve results later. This is useful for handling system interruptions, batch processing, or building queue-based systems. @@ -1129,4 +1091,4 @@ async def main(): # Your code here ``` -**Note:** For long-running video operations, consider using webhooks or `skipResponse=True` to avoid timeout issues with extended generation times. \ No newline at end of file +**Note:** For long-running video operations, consider using webhooks or `deliveryMethod="async"` to avoid timeout issues with extended generation times. \ No newline at end of file diff --git a/runware/base.py b/runware/base.py index ef819d6..db4ee47 100644 --- a/runware/base.py +++ b/runware/base.py @@ -4,7 +4,8 @@ import os import re from asyncio import gather -from dataclasses import asdict +from dataclasses import asdict, is_dataclass, fields +from enum import Enum from random import uniform from typing import List, Optional, Union, Callable, Any, Dict, Tuple @@ -772,6 +773,7 @@ async def _imageInference( task_uuid = requestImage.taskUUID number_results = requestImage.numberResults or 1 + if delivery_method_enum is EDeliveryMethod.ASYNC: if requestImage.webhookURL: request_object["webhookURL"] = requestImage.webhookURL @@ -2067,13 +2069,6 @@ async def _requestVideo(self, requestVideo: "IVideoInference") -> "Union[List[IV if requestVideo.webhookURL: request_object["webhookURL"] = requestVideo.webhookURL - if requestVideo.skipResponse: - await self.send([request_object]) - return IAsyncTaskResponse( - taskType=ETaskType.VIDEO_INFERENCE.value, - taskUUID=requestVideo.taskUUID - ) - return await self._handleInitialVideoResponse( request_object=request_object, task_uuid=requestVideo.taskUUID, @@ -2132,8 +2127,9 @@ def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]: if requestVideo.positivePrompt is not None: request_object["positivePrompt"] = requestVideo.positivePrompt.strip() + self._addOptionalBuiltInDataTypesFields(request_object, requestVideo) + self._addOptionalField(request_object, requestVideo.speech) - self._addOptionalVideoFields(request_object, requestVideo) self._addVideoImages(request_object, requestVideo) self._addOptionalField(request_object, requestVideo.inputs) self._addProviderSettings(request_object, requestVideo) @@ -2144,18 +2140,6 @@ def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]: return request_object - def _addOptionalVideoFields(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None: - optional_fields = [ - "outputType", "outputFormat", "outputQuality", "uploadEndpoint", - "includeCost", "negativePrompt", "inputAudios", "referenceVideos", "fps", "steps", "scheduler", "seed", - "CFGScale", "seedImage", "duration", "width", "height", "nsfw_check", "resolution", - ] - - for field in optional_fields: - value = getattr(requestVideo, field, None) - if value is not None: - request_object[field] = value - def _addVideoImages(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None: if requestVideo.frameImages: frame_images_data = [] @@ -2367,7 +2351,7 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str if prompt: request_object["positivePrompt"] = prompt - self._addOptionalImageFields(request_object, requestImage) + self._addOptionalBuiltInDataTypesFields(request_object, requestImage) self._addImageSpecialFields(request_object, requestImage, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data) self._addOptionalField(request_object, requestImage.inputs) self._addImageProviderSettings(request_object, requestImage) @@ -2378,24 +2362,6 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str return request_object - def _addOptionalImageFields(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None: - optional_fields = [ - "outputType", "outputFormat", "outputQuality", "uploadEndpoint", - "includeCost", "checkNsfw", "negativePrompt", "seedImage", "maskImage", - "strength", "height", "width", "steps", "scheduler", "seed", "CFGScale", - "clipSkip", "promptWeighting", "maskMargin", "vae", "webhookURL", "acceleration", - "useCache", "ttl", "resolution" - ] - - for field in optional_fields: - value = getattr(requestImage, field, None) - if value is not None: - # Special handling for checkNsfw -> checkNSFW - if field == "checkNsfw": - request_object["checkNSFW"] = value - else: - request_object[field] = value - def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: IImageInference, control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> None: # Add controlNet if present if control_net_data_dicts: @@ -2476,6 +2442,49 @@ def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: I if hasattr(requestImage, "extraArgs") and isinstance(requestImage.extraArgs, dict): request_object.update(requestImage.extraArgs) + def _convert_enums(self, val: Any) -> Any: + if is_dataclass(val): + return val + if isinstance(val, Enum): + return val.value + if isinstance(val, list): + return [self._convert_enums(v) for v in val] + if isinstance(val, tuple): + return tuple(self._convert_enums(v) for v in val) + if isinstance(val, dict): + return { + self._convert_enums(k) if isinstance(k, Enum) else k: self._convert_enums(v) + for k, v in val.items() + } + return val + + def _addOptionalBuiltInDataTypesFields(self, request_object: Dict[str, Any], obj: Any) -> None: + if not is_dataclass(obj): + return + + cls = obj.__class__ + + for field in fields(cls): + name = field.name + value = getattr(obj, name, None) + + if ( + name in request_object + or name == "extraArgs" + or value is None + or (isinstance(value, (list, tuple, dict)) and not value) + or callable(value) + or is_dataclass(value) + or ( + isinstance(value, (list, tuple)) + and value + and any(is_dataclass(v) for v in value) + ) + ): + continue + + request_object[name] = self._convert_enums(value) + def _addSafetySettings(self, request_object: Dict[str, Any], safety: ISafety) -> None: safety_dict = asdict(safety) safety_dict = {k: v for k, v in safety_dict.items() if v is not None} @@ -2835,6 +2844,7 @@ async def _requestAudio(self, requestAudio: "IAudioInference") -> Union[List["IA requestAudio.taskUUID = requestAudio.taskUUID or getUUID() request_object = self._buildAudioRequest(requestAudio) + return await self._handleInitialAudioResponse( request_object=request_object, task_uuid=requestAudio.taskUUID, @@ -2861,26 +2871,14 @@ def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]: if requestAudio.duration is not None: request_object["duration"] = requestAudio.duration - self._addOptionalAudioFields(request_object, requestAudio) + self._addOptionalBuiltInDataTypesFields(request_object, requestAudio) self._addOptionalField(request_object, requestAudio.speech) self._addOptionalField(request_object, requestAudio.audioSettings) self._addOptionalField(request_object, requestAudio.settings) self._addAudioProviderSettings(request_object, requestAudio) self._addOptionalField(request_object, requestAudio.inputs) - self._addOptionalField(request_object, requestAudio.settings) - - return request_object - - def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None: - optional_fields = [ - "outputType", "outputFormat", "includeCost", "uploadEndpoint", "webhookURL", - "negativePrompt", "steps", "seed", "CFGScale", "strength" - ] - for field in optional_fields: - value = getattr(requestAudio, field, None) - if value is not None: - request_object[field] = value + return request_object def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None: diff --git a/runware/types.py b/runware/types.py index 7d32aa8..34c565b 100644 --- a/runware/types.py +++ b/runware/types.py @@ -1,6 +1,6 @@ from abc import abstractmethod, ABC from enum import Enum -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, field, asdict, InitVar from typing import List, Union, Optional, Callable, Any, Dict, TypeVar, Literal import warnings @@ -916,7 +916,7 @@ class IImageInference: outputType: Optional[IOutputType] = None outputFormat: Optional[IOutputFormat] = None uploadEndpoint: Optional[str] = None - checkNsfw: Optional[bool] = None + checkNsfw: InitVar[Optional[bool]] = None negativePrompt: Optional[str] = None seedImage: Optional[Union[File, str]] = None maskImage: Optional[Union[File, str]] = None @@ -960,7 +960,19 @@ class IImageInference: webhookURL: Optional[str] = None ttl: Optional[int] = None # time-to-live (TTL) in seconds, only applies when outputType is "URL" - def __post_init__(self): + def __post_init__(self, checkNsfw: Optional[bool] = None): + if checkNsfw is not None: + warnings.warn( + "checkNsfw has been deprecated and will be removed in a future version; please use safety.checkContent instead.", + DeprecationWarning, + stacklevel=2, + ) + if checkNsfw: + if isinstance(self.safety, dict): + self.safety.setdefault("checkContent", True) + elif self.safety is not None and hasattr(self.safety, "checkContent"): + if getattr(self.safety, "checkContent") is None: + self.safety.checkContent = True if self.safety is not None and isinstance(self.safety, dict): self.safety = ISafety(**self.safety) if self.settings is not None and isinstance(self.settings, dict): @@ -1485,15 +1497,21 @@ class IVideoInference: advancedFeatures: Optional[IVideoAdvancedFeatures] = None acceleratorOptions: Optional[IAcceleratorOptions] = None inputs: Optional[Union[IVideoInputs, Dict[str, Any]]] = None - skipResponse: Optional[bool] = False resolution: Optional[str] = None settings: Optional[Union[ISettings, Dict[str, Any]]] = None + skipResponse: InitVar[Optional[bool]] = None - def __post_init__(self): + def __post_init__(self, skipResponse: Optional[bool] = None) -> None: + if skipResponse is not None: + warnings.warn( + "skipResponse has been deprecated; use deliveryMethod='async' instead", + DeprecationWarning, + stacklevel=2, + ) + if skipResponse and getattr(self, "deliveryMethod", None) is None: + self.deliveryMethod = "async" if self.settings is not None and isinstance(self.settings, dict): self.settings = ISettings(**self.settings) - - def __post_init__(self): if self.safety is not None and isinstance(self.safety, dict): self.safety = ISafety(**self.safety) if self.inputs is not None and isinstance(self.inputs, dict):