Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,67 @@ async def main() -> None:
- Use `getResponse(taskUUID)` to retrieve results at any time
- `deliveryMethod="sync"` waits for complete results (may timeout for long-running tasks)

### Training (Async Only)

Training is a long-running task type.
- `deliveryMethod="sync"` is not supported for training and raises a `ValueError`
- Use `deliveryMethod="async"` and retrieve final results with `getResponse(taskUUID)`

```python
import uuid
from runware import Runware, ITraining, ITrainingImportModel, ITrainingInputs

async def main() -> None:
runware = Runware(api_key=RUNWARE_API_KEY)
await runware.connect()

request_training = ITraining(
taskUUID=str(uuid.uuid4()),
model="runware:illustrative@training",
deliveryMethod="async",
importModel=ITrainingImportModel(
air="runware:illustrative@0",
name="Runware Illustrative Training Model",
uniqueIdentifier="runware_illustrative_model_1",
version="1.0",
private=False,
heroImageURL="https://example.com/hero-image.png",
shortDescription="First training model",
),
inputs=ITrainingInputs(
dataset="example/pictures.zip"
),
)

training_task = await runware.training(requestTraining=request_training)
results = await runware.getResponse(taskUUID=training_task.taskUUID)
print(results)
```

After training, you can run image inference with the trained model:

```python
import uuid
from runware import Runware, IImageInference, IInputs, IInputReference

async def main() -> None:
runware = Runware(api_key=RUNWARE_API_KEY)
await runware.connect()

request_image = IImageInference(
taskUUID=str(uuid.uuid4()),
model="runware:illustrative@0",
positivePrompt="a horse",
numberResults=1,
width=1024,
height=1025,
deliveryMethod="sync",
)

image_task = await runware.imageInference(requestImage=request_image)
print(f"Image inference task submitted: {image_task.taskUUID}")
```

### Retrieving Original Task Request/Response

To inspect the original request payload and response for a past task, use `getTaskDetails(taskUUID)`.
Expand Down
141 changes: 137 additions & 4 deletions runware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
OperationState,
I3dInference,
I3d,
ITraining,
ITrainingResult,
IGetResponseRequest,
IGetTaskDetailsRequest,
IUploadImageRequest,
Expand Down Expand Up @@ -2169,6 +2171,14 @@ async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsync
await self.ensureConnection()
return await self._request3d(request3d)

async def training(self, requestTraining: ITraining) -> Union[List[ITrainingResult], IAsyncTaskResponse]:
async with self._request_semaphore:
return await self._retry_async_with_reconnect(
self._requestTraining,
requestTraining,
task_type=ETaskType.TRAINING.value,
)

Comment thread
Sirsho1997 marked this conversation as resolved.
async def textInference(
self, requestText: ITextInference
) -> Union[List[IText], IAsyncTaskResponse, AsyncIterator[Union[str, IText]]]:
Expand All @@ -2194,7 +2204,7 @@ async def getResponse(
self,
taskUUID: str,
numberResults: Optional[int] = 1,
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]:
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText], List[ITrainingResult]]:
async with self._request_semaphore:
request = IGetResponseRequest(
taskUUID=taskUUID,
Expand All @@ -2205,7 +2215,7 @@ async def getResponse(
async def _getResponse(
self,
request_model: IGetResponseRequest,
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]:
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText], List[ITrainingResult]]:
await self.ensureConnection()

return await self._pollResults(
Expand Down Expand Up @@ -2284,6 +2294,7 @@ def _normalizeTaskDetailsRequest(self, request_items: List[Any]) -> List[Any]:
ETaskType.VIDEO_UPSCALE.value: IVideoUpscale,
ETaskType.AUDIO_INFERENCE.value: IAudioInference,
ETaskType.INFERENCE_3D.value: I3dInference,
ETaskType.TRAINING.value: ITraining,
ETaskType.TEXT_INFERENCE.value: ITextInference,
ETaskType.GET_RESPONSE.value: IGetResponseRequest,
ETaskType.GET_TASK_DETAILS.value: IGetTaskDetailsRequest,
Expand Down Expand Up @@ -2314,6 +2325,7 @@ def _normalizeTaskDetailsResponse(self, response_payload: Any) -> Any:
ETaskType.VIDEO_BACKGROUND_REMOVAL.value: IVideo,
ETaskType.VIDEO_UPSCALE.value: IVideo,
ETaskType.INFERENCE_3D.value: I3d,
ETaskType.TRAINING.value: ITrainingResult,
ETaskType.TEXT_INFERENCE.value: IText,
ETaskType.PROMPT_ENHANCE.value: IEnhancedPrompt,
ETaskType.GET_TASK_DETAILS.value: ITaskDetails,
Expand Down Expand Up @@ -2529,6 +2541,41 @@ async def _request3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTa
debug_key="3d-inference-initial",
)

async def _requestTraining(self, requestTraining: ITraining) -> Union[List[ITrainingResult], IAsyncTaskResponse]:
await self.ensureConnection()
requestTraining.taskUUID = requestTraining.taskUUID or getUUID()

if requestTraining.inputs.dataset:
requestTraining.inputs.dataset = await self._process_media(requestTraining.inputs.dataset)
Comment thread
Sirsho1997 marked this conversation as resolved.
Comment thread
Sirsho1997 marked this conversation as resolved.

request_object = self._buildTrainingRequest(requestTraining)

return await self._handleInitialTrainingResponse(
request_object=request_object,
task_uuid=requestTraining.taskUUID,
number_results=1,
delivery_method=requestTraining.deliveryMethod,
webhook_url=requestTraining.webhookURL,
debug_key="training-initial",
)

def _buildTrainingRequest(self, requestTraining: ITraining) -> Dict[str, Any]:
request_object: Dict[str, Any] = {
"taskType": ETaskType.TRAINING.value,
"taskUUID": requestTraining.taskUUID,
"model": requestTraining.model,
"deliveryMethod": requestTraining.deliveryMethod,
}

if requestTraining.includeCost is not None:
request_object["includeCost"] = requestTraining.includeCost
if requestTraining.webhookURL is not None:
request_object["webhookURL"] = requestTraining.webhookURL

self._addOptionalField(request_object, requestTraining.importModel)
self._addOptionalField(request_object, requestTraining.inputs)
return request_object

def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
request_object: Dict[str, Any] = {
"taskType": ETaskType.TEXT_INFERENCE.value,
Expand Down Expand Up @@ -3161,6 +3208,85 @@ def is_3d_complete(r: Dict[str, Any]) -> bool:
finally:
await self._unregister_pending_operation(task_uuid)

async def _handleInitialTrainingResponse(
self,
request_object: Dict[str, Any],
task_uuid: str,
number_results: int,
delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.ASYNC,
webhook_url: Optional[str] = None,
debug_key: str = "training-initial",
) -> Union[List[ITrainingResult], IAsyncTaskResponse]:
delivery_method_enum = (
delivery_method
if isinstance(delivery_method, EDeliveryMethod)
else EDeliveryMethod(delivery_method)
)

def is_training_complete(r: Dict[str, Any]) -> bool:
if r.get("status") == "success":
return True
if r.get("outputs") is not None:
return True
if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC:
return True
return False

if delivery_method_enum is EDeliveryMethod.SYNC and not webhook_url:
future, should_send = await self._register_pending_operation(
task_uuid,
expected_results=number_results or 1,
complete_predicate=None,
result_filter=lambda r: (
r.get("status") == "success"
or r.get("outputs") is not None
),
)
else:
future, should_send = await self._register_pending_operation(
task_uuid,
expected_results=1,
complete_predicate=is_training_complete,
)

timeout = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else VIDEO_INITIAL_TIMEOUT

try:
if should_send:
await self.send([request_object])
await self._mark_operation_sent(task_uuid)

results = await asyncio.wait_for(future, timeout=timeout / 1000)
if not results:
raise ConnectionError(
f"No initial response received for training | "
f"delivery_method={delivery_method_enum} | taskUUID={task_uuid}"
)

response = results[0]
self._handle_error_response(response)

if response.get("status") == "success" or response.get("outputs") is not None:
return instantiateDataclassList(ITrainingResult, results)

if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC:
return createAsyncTaskResponse(response)

return instantiateDataclassList(ITrainingResult, results)

except asyncio.TimeoutError:
if not self.connected() or not self.isWebsocketReadyState():
raise ConnectionError(
f"Connection lost while waiting for training response | "
f"TaskUUID: {task_uuid} | Delivery method: {delivery_method_enum}"
)
raise ConnectionError(
f"Timeout waiting for training response | TaskUUID: {task_uuid} | "
f"Timeout: {timeout}ms"
)
finally:
await self._unregister_pending_operation(task_uuid)

async def _handleInitialImageResponse(
self,
task_uuid: str,
Expand Down Expand Up @@ -3422,15 +3548,15 @@ async def _pollResults(
self,
task_uuid: str,
number_results: Optional[int],
) -> Union[List[IVideo], List[IVideoToText], List[IAudio], List[IImage], List[I3d], List[IText]]:
) -> Union[List[IVideo], List[IVideoToText], List[IAudio], List[IImage], List[I3d], List[IText], List[ITrainingResult]]:
# Default to 1 if number_results is None
if number_results is None:
number_results = 1

completed_results: "List[Dict[str, Any]]" = []

task_type = None
response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d, IText]] = None
response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d, IText, ITrainingResult]] = None
max_polls: int = MAX_POLLS
polling_delay: int = VIDEO_POLLING_DELAY
timeout_message: str = f"Polling timeout after {MAX_POLLS} polls"
Expand Down Expand Up @@ -3497,6 +3623,13 @@ def configure_from_task_type(task_type_val: Optional[str]):
TEXT_POLLING_DELAY,
f"Text generation timeout after {MAX_POLLS} polls"
)
case ETaskType.TRAINING.value:
return (
ITrainingResult,
MAX_POLLS,
VIDEO_POLLING_DELAY,
f"Training timeout after {MAX_POLLS} polls"
)
Comment thread
Sirsho1997 marked this conversation as resolved.
case _:
raise ValueError(f"Unsupported task type for polling: {task_type_val}")

Expand Down
55 changes: 55 additions & 0 deletions runware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ETaskType(Enum):
GET_RESPONSE = "getResponse"
GET_TASK_DETAILS = "getTaskDetails"
IMAGE_VECTORIZE = "vectorize"
TRAINING = "training"


class EPreProcessorGroup(Enum):
Expand Down Expand Up @@ -1239,6 +1240,32 @@ def request_key(self) -> str:
return "inputs"


@dataclass
class ITrainingImportModel(SerializableMixin):
air: str
name: str
uniqueIdentifier: str
version: str
private: bool
heroImageURL: Optional[str] = None
shortDescription: Optional[str] = None
architecture: Optional[str] = None
category: Optional[str] = None

@property
def request_key(self) -> str:
return "importModel"


@dataclass
class ITrainingInputs(SerializableMixin):
dataset: Optional[Union[str, File]] = None
Comment thread
Sirsho1997 marked this conversation as resolved.

@property
def request_key(self) -> str:
return "inputs"


@dataclass
class IImageInference:
model: Union[int, str]
Expand Down Expand Up @@ -1944,6 +1971,34 @@ def __post_init__(self):
self.inputs = I3dInputs(**self.inputs)


@dataclass
class ITraining:
model: str
importModel: Union[ITrainingImportModel, Dict[str, Any]]
inputs: Union[ITrainingInputs, Dict[str, Any]]
taskUUID: Optional[str] = None
Comment thread
Sirsho1997 marked this conversation as resolved.
deliveryMethod: str = "async"
includeCost: Optional[bool] = None
webhookURL: Optional[str] = None

def __post_init__(self):
if self.deliveryMethod == "sync":
Comment thread
Sirsho1997 marked this conversation as resolved.
raise ValueError("ITraining is a long-running task. Please use 'async' delivery method.")
if self.importModel is not None and isinstance(self.importModel, dict):
self.importModel = ITrainingImportModel(**self.importModel)
if self.inputs is not None and isinstance(self.inputs, dict):
self.inputs = ITrainingInputs(**self.inputs)


@dataclass
class ITrainingResult:
taskType: str
taskUUID: str
status: Optional[str] = None
cost: Optional[float] = None
outputs: Optional[Dict[str, Any]] = None


@dataclass
class IAudioInputs(SerializableMixin):
audio: Optional[str] = None
Expand Down