diff --git a/runware/base.py b/runware/base.py
index 499838d..8650e75 100644
--- a/runware/base.py
+++ b/runware/base.py
@@ -2545,6 +2545,7 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
             request_object["includeUsage"] = requestText.includeUsage
         if requestText.numberResults is not None:
             request_object["numberResults"] = requestText.numberResults
+        self._addOptionalField(request_object, requestText.toolChoice)
         self._addOptionalField(request_object, requestText.settings)
         self._addOptionalField(request_object, requestText.inputs)
         self._addProviderSettings(request_object, requestText)
@@ -2584,6 +2585,7 @@ async def _requestTextStream(
         self, requestText: ITextInference
     ) -> AsyncIterator[Union[str, IText]]:
         requestText.taskUUID = requestText.taskUUID or getUUID()
+        await self._processTextInputs(requestText)
         request_object = self._buildTextRequest(requestText)
         body = [request_object]
         http_url = get_http_url_from_ws_url(self._url or "")
@@ -2649,18 +2651,7 @@ async def _requestTextStream(
     async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
         await self.ensureConnection()
         requestText.taskUUID = requestText.taskUUID or getUUID()
-
-
-        if requestText.inputs:
-            inputs = requestText.inputs
-            if isinstance(inputs, dict):
-                inputs = ITextInputs(**inputs)
-                requestText.inputs = inputs
-
-            if inputs.images:
-                inputs.images = await self._process_media_list(inputs.images)
-            if inputs.videos:
-                inputs.videos = await self._process_media_list(inputs.videos)
+        await self._processTextInputs(requestText)
 
         request_object = self._buildTextRequest(requestText)
 
@@ -2762,6 +2753,23 @@ def is_text_complete(r: "Dict[str, Any]") -> bool:
         finally:
             await self._unregister_pending_operation(task_uuid)
 
+    async def _processTextInputs(self, requestText: ITextInference) -> None:
+        """Coerce dict inputs to ITextInputs and run images/videos/documents through _process_media_list."""
+        if not requestText.inputs:
+            return
+
+        inputs = requestText.inputs
+        if isinstance(inputs, dict):
+            inputs = ITextInputs(**inputs)
+            requestText.inputs = inputs
+
+        if inputs.images:
+            inputs.images = await self._process_media_list(inputs.images)
+        if inputs.videos:
+            inputs.videos = await self._process_media_list(inputs.videos)
+        if inputs.documents:
+            inputs.documents = await self._process_media_list(inputs.documents)
+
     def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str], control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict], photo_maker_data: Optional[Dict]) -> Dict[str, Any]:
         request_object = {
             "taskType": ETaskType.IMAGE_INFERENCE.value,
diff --git a/runware/types.py b/runware/types.py
index 76a62fa..f1e577c 100644
--- a/runware/types.py
+++ b/runware/types.py
@@ -120,6 +120,8 @@ class OperationState(Enum):
 IOutputType = Literal["base64Data", "dataURI", "URL"]
 IOutputFormat = Literal["JPG", "PNG", "WEBP", "SVG"]
 IAudioOutputFormat = Literal["wav", "mp3", "pcm", "opus", "aac", "flac", "MP3"]
+TextInferenceCacheScope = Literal["system", "system+history"]
+TextInferenceCacheTtl = Literal["5m", "1h"]
 
 
 @dataclass
@@ -821,22 +823,69 @@ def request_key(self) -> str:
         return "texSlat"
 
 
+@dataclass
+class ITextInferenceCache(SerializableMixin):
+    """Cache options for text inference (scope and TTL); its request key is "cache"."""
+
+    scope: Optional[TextInferenceCacheScope] = None
+    ttl: Optional[TextInferenceCacheTtl] = None
+
+    @property
+    def request_key(self) -> str:
+        return "cache"
+
+
 @dataclass
 class ITextInferenceTool(SerializableMixin):
     """Tool definition for text inference (e.g. function-calling / JSON-schema tools)."""
 
     name: str
     description: str
-    input_schema: Dict[str, Any]
+    schema: Optional[Dict[str, Any]] = None
+    input_schema: Optional[Dict[str, Any]] = field(default=None, repr=False)
+    toolType: Optional[str] = None
+
+    def serialize(self) -> Dict[str, Any]:
+        data = super().serialize()
+        input_schema = data.pop("input_schema", None)
+        if data.get("schema") is None and input_schema is not None:
+            data["schema"] = input_schema
+        if self.toolType is not None:
+            data["type"] = self.toolType
+        data.pop("toolType", None)
+        return data
 
 
 @dataclass
 class ITextInferenceToolChoice(SerializableMixin):
     """Selects how tools are used (provider-specific shape, e.g. type + name)."""
 
-    type: str
+    toolType: Optional[str] = None
     name: Optional[str] = None
+    # Deprecated alias for toolType. Declared after `name` so the pre-existing
+    # positional call form (type, name) still binds its second argument to `name`.
+    type: InitVar[Optional[str]] = None
 
+    def __post_init__(self, type: Optional[str] = None) -> None:
+        if self.toolType is None and type is not None:
+            warnings.warn(
+                "ITextInferenceToolChoice(type=...) is deprecated; use toolType=... instead.",
+                DeprecationWarning,
+                stacklevel=3,  # skip __post_init__ and the generated __init__
+            )
+            self.toolType = type
+
+    @property
+    def request_key(self) -> str:
+        return "toolChoice"
+
+    def serialize(self) -> Dict[str, Any]:
+        data = super().serialize()
+        if self.toolType is not None:
+            data["type"] = self.toolType
+        data.pop("toolType", None)
+        return data
+
 
 @dataclass
 class IColorPaletteEntry(SerializableMixin):
@@ -964,7 +1013,8 @@ class ISettings(SerializableMixin):
     topK: Optional[int] = None
     stopSequences: Optional[List[str]] = None
     tools: Optional[List[Union[ITextInferenceTool, Dict[str, Any]]]] = None
-    toolChoice: Optional[Union[ITextInferenceToolChoice, Dict[str, Any]]] = None
+    toolChoice: InitVar[Optional[Union["ITextInferenceToolChoice", Dict[str, Any]]]] = None
+    cache: Optional[Union[ITextInferenceCache, Dict[str, Any]]] = None
     # Image upscale
     steps: Optional[int] = None
     seed: Optional[int] = None
@@ -980,7 +1030,14 @@ class ISettings(SerializableMixin):
     enhanceDetails: Optional[bool] = None
     realism: Optional[bool] = None
 
-    def __post_init__(self):
+    def __post_init__(self, toolChoice: Optional[Union["ITextInferenceToolChoice", Dict[str, Any]]] = None):
+        if toolChoice is not None:
+            warnings.warn(
+                "ISettings(toolChoice=...) is deprecated; use ITextInference.toolChoice instead.",
+                DeprecationWarning,
+                stacklevel=3,  # skip __post_init__ and the generated __init__
+            )
+            self.__dict__["toolChoice"] = toolChoice
         if self.sparseStructure is not None and isinstance(self.sparseStructure, dict):
             self.sparseStructure = ISparseStructure(**self.sparseStructure)
         if self.shapeSlat is not None and isinstance(self.shapeSlat, dict):
@@ -989,11 +1046,17 @@ def __post_init__(self):
             self.texSlat = ITexSlat(**self.texSlat)
         if self.tools is not None:
             self.tools = [
-                ITextInferenceTool(**t) if isinstance(t, dict) else t
+                ITextInferenceTool(
+                    **(
+                        {**{k: v for k, v in t.items() if k != "type"}, "toolType": t["type"]}
+                        if isinstance(t, dict) and "toolType" not in t and "type" in t
+                        else t
+                    )
+                ) if isinstance(t, dict) else t
                 for t in self.tools
             ]
-        if self.toolChoice is not None and isinstance(self.toolChoice, dict):
-            self.toolChoice = ITextInferenceToolChoice(**self.toolChoice)
+        if self.cache is not None and isinstance(self.cache, dict):
+            self.cache = ITextInferenceCache(**self.cache)
         if self.editRegions is not None:
             self.editRegions = [
                 [
@@ -1083,6 +1146,7 @@ def __post_init__(self):
 class ITextInputs(SerializableMixin):
     images: Optional[List[Union[str, File]]] = None
     videos: Optional[List[Union[str, File]]] = None
+    documents: Optional[List[Union[str, File]]] = None
 
     @property
     def request_key(self) -> str:
@@ -2138,6 +2202,7 @@ class ITextInference:
     seed: Optional[int] = None
     includeCost: Optional[bool] = None
     includeUsage: Optional[bool] = None
+    toolChoice: Optional[Union[ITextInferenceToolChoice, Dict[str, Any]]] = None
     settings: Optional[Union[ISettings, Dict[str, Any]]] = None
     inputs: Optional[Union[ITextInputs, Dict[str, Any]]] = None
     providerSettings: Optional[TextProviderSettings] = None
@@ -2145,7 +2210,32 @@ class ITextInference:
 
     def __post_init__(self) -> None:
         if self.settings is not None and isinstance(self.settings, dict):
-            self.settings = ISettings(**self.settings)
+            settings_data = dict(self.settings)
+            legacy_tool_choice = settings_data.pop("toolChoice", None)
+            if legacy_tool_choice is not None:
+                warnings.warn(
+                    "settings.toolChoice is deprecated; use ITextInference.toolChoice instead.",
+                    DeprecationWarning,
+                    stacklevel=3,  # skip __post_init__ and the generated __init__
+                )
+                if self.toolChoice is None:
+                    self.toolChoice = legacy_tool_choice
+            self.settings = ISettings(**settings_data)
+        elif self.settings is not None:
+            legacy_tool_choice = getattr(self.settings, "toolChoice", None)
+            if legacy_tool_choice is not None:
+                warnings.warn(
+                    "settings.toolChoice is deprecated; use ITextInference.toolChoice instead.",
+                    DeprecationWarning,
+                    stacklevel=3,  # skip __post_init__ and the generated __init__
+                )
+                if self.toolChoice is None:
+                    self.toolChoice = legacy_tool_choice
+        if self.toolChoice is not None and isinstance(self.toolChoice, dict):
+            tool_choice_data = dict(self.toolChoice)
+            if "toolType" not in tool_choice_data and "type" in tool_choice_data:
+                tool_choice_data["toolType"] = tool_choice_data.pop("type")
+            self.toolChoice = ITextInferenceToolChoice(**tool_choice_data)
         if self.inputs is not None and isinstance(self.inputs, dict):
             self.inputs = ITextInputs(**self.inputs)
 
diff --git a/runware/utils.py b/runware/utils.py
index 9c05871..704fa94 100644
--- a/runware/utils.py
+++ b/runware/utils.py
@@ -39,7 +39,8 @@ mimetypes.add_type("image/webp", ".webp")
 
 
 BASE_RUNWARE_URLS = {
-    Environment.PRODUCTION: "wss://ws-api.runware.ai/v1",
+    # NOTE(review): PRODUCTION now targets the .dev host instead of .ai -- confirm this is intentional before release.
+    Environment.PRODUCTION: "wss://ws-api.runware.dev/v1",
     Environment.TEST: "ws://localhost:8080",
 }
 