Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions runware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2545,6 +2545,7 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
request_object["includeUsage"] = requestText.includeUsage
if requestText.numberResults is not None:
request_object["numberResults"] = requestText.numberResults
self._addOptionalField(request_object, requestText.toolChoice)
self._addOptionalField(request_object, requestText.settings)
self._addOptionalField(request_object, requestText.inputs)
self._addProviderSettings(request_object, requestText)
Expand Down Expand Up @@ -2584,6 +2585,7 @@ async def _requestTextStream(
self, requestText: ITextInference
) -> AsyncIterator[Union[str, IText]]:
requestText.taskUUID = requestText.taskUUID or getUUID()
await self._processTextInputs(requestText)
request_object = self._buildTextRequest(requestText)
body = [request_object]
http_url = get_http_url_from_ws_url(self._url or "")
Expand Down Expand Up @@ -2649,18 +2651,7 @@ async def _requestTextStream(
async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
await self.ensureConnection()
requestText.taskUUID = requestText.taskUUID or getUUID()


if requestText.inputs:
inputs = requestText.inputs
if isinstance(inputs, dict):
inputs = ITextInputs(**inputs)
requestText.inputs = inputs

if inputs.images:
inputs.images = await self._process_media_list(inputs.images)
if inputs.videos:
inputs.videos = await self._process_media_list(inputs.videos)
await self._processTextInputs(requestText)

request_object = self._buildTextRequest(requestText)

Expand Down Expand Up @@ -2762,6 +2753,22 @@ def is_text_complete(r: "Dict[str, Any]") -> bool:
finally:
await self._unregister_pending_operation(task_uuid)

async def _processTextInputs(self, requestText: ITextInference) -> None:
if not requestText.inputs:
return

inputs = requestText.inputs
if isinstance(inputs, dict):
inputs = ITextInputs(**inputs)
requestText.inputs = inputs

if inputs.images:
inputs.images = await self._process_media_list(inputs.images)
if inputs.videos:
inputs.videos = await self._process_media_list(inputs.videos)
if inputs.documents:
inputs.documents = await self._process_media_list(inputs.documents)
Comment thread
Sirsho1997 marked this conversation as resolved.

Comment thread
Sirsho1997 marked this conversation as resolved.
Comment thread
Sirsho1997 marked this conversation as resolved.
def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str], control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict], photo_maker_data: Optional[Dict]) -> Dict[str, Any]:
request_object = {
"taskType": ETaskType.IMAGE_INFERENCE.value,
Expand Down
103 changes: 95 additions & 8 deletions runware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class OperationState(Enum):
IOutputType = Literal["base64Data", "dataURI", "URL"]
IOutputFormat = Literal["JPG", "PNG", "WEBP", "SVG"]
IAudioOutputFormat = Literal["wav", "mp3", "pcm", "opus", "aac", "flac", "MP3"]
TextInferenceCacheScope = Literal["system", "system+history"]
TextInferenceCacheTtl = Literal["5m", "1h"]


@dataclass
Expand Down Expand Up @@ -821,22 +823,66 @@ def request_key(self) -> str:
return "texSlat"


@dataclass
class ITextInferenceCache(SerializableMixin):

scope: Optional[TextInferenceCacheScope] = None
ttl: Optional[TextInferenceCacheTtl] = None

@property
def request_key(self) -> str:
return "cache"


@dataclass
class ITextInferenceTool(SerializableMixin):
"""Tool definition for text inference (e.g. function-calling / JSON-schema tools)."""

name: str
description: str
input_schema: Dict[str, Any]
schema: Optional[Dict[str, Any]] = None
input_schema: Optional[Dict[str, Any]] = field(default=None, repr=False)
toolType: Optional[str] = None
Comment thread
Sirsho1997 marked this conversation as resolved.

def serialize(self) -> Dict[str, Any]:
data = super().serialize()
input_schema = data.pop("input_schema", None)
if data.get("schema") is None and input_schema is not None:
data["schema"] = input_schema
if self.toolType is not None:
data["type"] = self.toolType
data.pop("toolType", None)
return data
Comment thread
Sirsho1997 marked this conversation as resolved.


@dataclass
class ITextInferenceToolChoice(SerializableMixin):
"""Selects how tools are used (provider-specific shape, e.g. type + name)."""

type: str
toolType: Optional[str] = None
type: InitVar[Optional[str]] = None
name: Optional[str] = None
Comment thread
Sirsho1997 marked this conversation as resolved.

def __post_init__(self, type: Optional[str] = None) -> None:
if self.toolType is None and type is not None:
warnings.warn(
"ITextInferenceToolChoice(type=...) is deprecated; use toolType=... instead.",
DeprecationWarning,
stacklevel=2,
)
self.toolType = type

@property
def request_key(self) -> str:
return "toolChoice"

Comment thread
Sirsho1997 marked this conversation as resolved.
def serialize(self) -> Dict[str, Any]:
data = super().serialize()
if self.toolType is not None:
data["type"] = self.toolType
data.pop("toolType", None)
return data


@dataclass
class IColorPaletteEntry(SerializableMixin):
Expand Down Expand Up @@ -964,7 +1010,8 @@ class ISettings(SerializableMixin):
topK: Optional[int] = None
stopSequences: Optional[List[str]] = None
tools: Optional[List[Union[ITextInferenceTool, Dict[str, Any]]]] = None
toolChoice: Optional[Union[ITextInferenceToolChoice, Dict[str, Any]]] = None
toolChoice: InitVar[Optional[Union["ITextInferenceToolChoice", Dict[str, Any]]]] = None
cache: Optional[Union[ITextInferenceCache, Dict[str, Any]]] = None
Comment thread
Sirsho1997 marked this conversation as resolved.
# Image upscale
steps: Optional[int] = None
seed: Optional[int] = None
Expand All @@ -980,7 +1027,14 @@ class ISettings(SerializableMixin):
enhanceDetails: Optional[bool] = None
realism: Optional[bool] = None

def __post_init__(self):
def __post_init__(self, toolChoice: Optional[Union["ITextInferenceToolChoice", Dict[str, Any]]] = None):
if toolChoice is not None:
warnings.warn(
"ISettings(toolChoice=...) is deprecated; use ITextInference.toolChoice instead.",
DeprecationWarning,
stacklevel=2,
)
self.__dict__["toolChoice"] = toolChoice
if self.sparseStructure is not None and isinstance(self.sparseStructure, dict):
self.sparseStructure = ISparseStructure(**self.sparseStructure)
if self.shapeSlat is not None and isinstance(self.shapeSlat, dict):
Expand All @@ -989,11 +1043,17 @@ def __post_init__(self):
self.texSlat = ITexSlat(**self.texSlat)
if self.tools is not None:
self.tools = [
ITextInferenceTool(**t) if isinstance(t, dict) else t
ITextInferenceTool(
**(
{**{k: v for k, v in t.items() if k != "type"}, "toolType": t["type"]}
if isinstance(t, dict) and "toolType" not in t and "type" in t
else t
)
) if isinstance(t, dict) else t
Comment thread
Sirsho1997 marked this conversation as resolved.
for t in self.tools
]
if self.toolChoice is not None and isinstance(self.toolChoice, dict):
self.toolChoice = ITextInferenceToolChoice(**self.toolChoice)
if self.cache is not None and isinstance(self.cache, dict):
self.cache = ITextInferenceCache(**self.cache)
if self.editRegions is not None:
self.editRegions = [
[
Expand Down Expand Up @@ -1083,6 +1143,7 @@ def __post_init__(self):
class ITextInputs(SerializableMixin):
images: Optional[List[Union[str, File]]] = None
videos: Optional[List[Union[str, File]]] = None
documents: Optional[List[Union[str, File]]] = None
Comment thread
Sirsho1997 marked this conversation as resolved.

@property
def request_key(self) -> str:
Expand Down Expand Up @@ -2138,14 +2199,40 @@ class ITextInference:
seed: Optional[int] = None
includeCost: Optional[bool] = None
includeUsage: Optional[bool] = None
toolChoice: Optional[Union[ITextInferenceToolChoice, Dict[str, Any]]] = None
settings: Optional[Union[ISettings, Dict[str, Any]]] = None
inputs: Optional[Union[ITextInputs, Dict[str, Any]]] = None
providerSettings: Optional[TextProviderSettings] = None
webhookURL: Optional[str] = None

def __post_init__(self) -> None:
if self.settings is not None and isinstance(self.settings, dict):
self.settings = ISettings(**self.settings)
settings_data = dict(self.settings)
legacy_tool_choice = settings_data.pop("toolChoice", None)
if legacy_tool_choice is not None:
warnings.warn(
"settings.toolChoice is deprecated; use ITextInference.toolChoice instead.",
DeprecationWarning,
stacklevel=2,
)
if self.toolChoice is None:
self.toolChoice = legacy_tool_choice
self.settings = ISettings(**settings_data)
elif self.settings is not None:
legacy_tool_choice = getattr(self.settings, "toolChoice", None)
if legacy_tool_choice is not None:
warnings.warn(
"settings.toolChoice is deprecated; use ITextInference.toolChoice instead.",
DeprecationWarning,
stacklevel=2,
)
if self.toolChoice is None:
self.toolChoice = legacy_tool_choice
if self.toolChoice is not None and isinstance(self.toolChoice, dict):
tool_choice_data = dict(self.toolChoice)
if "toolType" not in tool_choice_data and "type" in tool_choice_data:
tool_choice_data["toolType"] = tool_choice_data.pop("type")
self.toolChoice = ITextInferenceToolChoice(**tool_choice_data)
if self.inputs is not None and isinstance(self.inputs, dict):
self.inputs = ITextInputs(**self.inputs)

Expand Down
2 changes: 1 addition & 1 deletion runware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
mimetypes.add_type("image/webp", ".webp")

BASE_RUNWARE_URLS = {
Environment.PRODUCTION: "wss://ws-api.runware.ai/v1",
Environment.PRODUCTION: "wss://ws-api.runware.dev/v1",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Environment.TEST: "ws://localhost:8080",
}

Expand Down