From c664d4e9099439d5759954e11ac544912d81ed10 Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Thu, 30 Apr 2026 17:56:28 +0530 Subject: [PATCH 1/5] Added support for taskType training --- runware/base.py | 143 +++++++++++++++++++++++++++++++++++++++++++++-- runware/types.py | 54 ++++++++++++++++++ 2 files changed, 193 insertions(+), 4 deletions(-) diff --git a/runware/base.py b/runware/base.py index 499838d..fd460eb 100644 --- a/runware/base.py +++ b/runware/base.py @@ -61,6 +61,8 @@ OperationState, I3dInference, I3d, + ITraining, + ITrainingResult, IGetResponseRequest, IGetTaskDetailsRequest, IUploadImageRequest, @@ -2169,6 +2171,14 @@ async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsync await self.ensureConnection() return await self._request3d(request3d) + async def training(self, requestTraining: ITraining) -> Union[List[ITrainingResult], IAsyncTaskResponse]: + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._requestTraining, + requestTraining, + task_type=ETaskType.TRAINING.value, + ) + async def textInference( self, requestText: ITextInference ) -> Union[List[IText], IAsyncTaskResponse, AsyncIterator[Union[str, IText]]]: @@ -2194,7 +2204,7 @@ async def getResponse( self, taskUUID: str, numberResults: Optional[int] = 1, - ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]: + ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText], List[ITrainingResult]]: async with self._request_semaphore: request = IGetResponseRequest( taskUUID=taskUUID, @@ -2205,7 +2215,7 @@ async def getResponse( async def _getResponse( self, request_model: IGetResponseRequest, - ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]: + ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText], List[ITrainingResult]]: await self.ensureConnection() return await self._pollResults( @@ -2284,6 +2294,7 @@ def _normalizeTaskDetailsRequest(self, request_items: List[Any]) -> List[Any]: ETaskType.VIDEO_UPSCALE.value: IVideoUpscale, ETaskType.AUDIO_INFERENCE.value: IAudioInference, ETaskType.INFERENCE_3D.value: I3dInference, + ETaskType.TRAINING.value: ITraining, ETaskType.TEXT_INFERENCE.value: ITextInference, ETaskType.GET_RESPONSE.value: IGetResponseRequest, ETaskType.GET_TASK_DETAILS.value: IGetTaskDetailsRequest, @@ -2314,6 +2325,7 @@ def _normalizeTaskDetailsResponse(self, response_payload: Any) -> Any: ETaskType.VIDEO_BACKGROUND_REMOVAL.value: IVideo, ETaskType.VIDEO_UPSCALE.value: IVideo, ETaskType.INFERENCE_3D.value: I3d, + ETaskType.TRAINING.value: ITrainingResult, ETaskType.TEXT_INFERENCE.value: IText, ETaskType.PROMPT_ENHANCE.value: IEnhancedPrompt, ETaskType.GET_TASK_DETAILS.value: ITaskDetails, @@ -2529,6 +2541,43 @@ async def _request3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTa debug_key="3d-inference-initial", ) + async def _requestTraining(self, requestTraining: ITraining) -> Union[List[ITrainingResult], IAsyncTaskResponse]: + await self.ensureConnection() + requestTraining.taskUUID = requestTraining.taskUUID or getUUID() + + if requestTraining.inputs.dataset: + requestTraining.inputs.dataset = await self._process_media(requestTraining.inputs.dataset) + + request_object = self._buildTrainingRequest(requestTraining) + + return await self._handleInitialTrainingResponse( + request_object=request_object, + task_uuid=requestTraining.taskUUID, + number_results=requestTraining.numberResults or 1, + delivery_method=requestTraining.deliveryMethod, + webhook_url=requestTraining.webhookURL, + debug_key="training-initial", + ) + + def _buildTrainingRequest(self, requestTraining: ITraining) -> Dict[str, Any]: + request_object: Dict[str, Any] = { + "taskType": ETaskType.TRAINING.value, + "taskUUID": requestTraining.taskUUID, + "model": requestTraining.model, + "deliveryMethod": requestTraining.deliveryMethod, + } + + if requestTraining.numberResults is not None: + request_object["numberResults"] = requestTraining.numberResults + if requestTraining.includeCost is not None: + request_object["includeCost"] = requestTraining.includeCost + if requestTraining.webhookURL is not None: + request_object["webhookURL"] = requestTraining.webhookURL + + self._addOptionalField(request_object, requestTraining.importModel) + self._addOptionalField(request_object, requestTraining.inputs) + return request_object + def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]: request_object: Dict[str, Any] = { "taskType": ETaskType.TEXT_INFERENCE.value, @@ -3161,6 +3210,85 @@ def is_3d_complete(r: Dict[str, Any]) -> bool: finally: await self._unregister_pending_operation(task_uuid) + async def _handleInitialTrainingResponse( + self, + request_object: Dict[str, Any], + task_uuid: str, + number_results: int, + delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.ASYNC, + webhook_url: Optional[str] = None, + debug_key: str = "training-initial", + ) -> Union[List[ITrainingResult], IAsyncTaskResponse]: + delivery_method_enum = ( + delivery_method + if isinstance(delivery_method, EDeliveryMethod) + else EDeliveryMethod(delivery_method) + ) + + def is_training_complete(r: Dict[str, Any]) -> bool: + if r.get("status") == "success": + return True + if r.get("outputs") is not None: + return True + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return True + return False + + if delivery_method_enum is EDeliveryMethod.SYNC and not webhook_url: + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=number_results or 1, + complete_predicate=None, + result_filter=lambda r: ( + r.get("status") == "success" + or r.get("outputs") is not None + ), + ) + else: + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=is_training_complete, + ) + + timeout = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else VIDEO_INITIAL_TIMEOUT + + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + + results = await asyncio.wait_for(future, timeout=timeout / 1000) + if not results: + raise ConnectionError( + f"No initial response received for training | " + f"delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + ) + + response = results[0] + self._handle_error_response(response) + + if response.get("status") == "success" or response.get("outputs") is not None: + return instantiateDataclassList(ITrainingResult, results) + + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return createAsyncTaskResponse(response) + + return instantiateDataclassList(ITrainingResult, results) + + except asyncio.TimeoutError: + if not self.connected() or not self.isWebsocketReadyState(): + raise ConnectionError( + f"Connection lost while waiting for training response | " + f"TaskUUID: {task_uuid} | Delivery method: {delivery_method_enum}" + ) + raise ConnectionError( + f"Timeout waiting for training response | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" + ) + finally: + await self._unregister_pending_operation(task_uuid) + async def _handleInitialImageResponse( self, task_uuid: str, @@ -3422,7 +3550,7 @@ async def _pollResults( self, task_uuid: str, number_results: Optional[int], - ) -> Union[List[IVideo], List[IVideoToText], List[IAudio], List[IImage], List[I3d], List[IText]]: + ) -> Union[List[IVideo], List[IVideoToText], List[IAudio], List[IImage], List[I3d], List[IText], List[ITrainingResult]]: # Default to 1 if number_results is None if number_results is None: number_results = 1 @@ -3430,7 +3558,7 @@ async def _pollResults( completed_results: "List[Dict[str, Any]]" = [] task_type = None - response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d, IText]] = None + response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d, IText, ITrainingResult]] = None max_polls: int = MAX_POLLS polling_delay: int = VIDEO_POLLING_DELAY timeout_message: str = f"Polling timeout after {MAX_POLLS} polls" @@ -3497,6 +3625,13 @@ def configure_from_task_type(task_type_val: Optional[str]): TEXT_POLLING_DELAY, f"Text generation timeout after {MAX_POLLS} polls" ) + case ETaskType.TRAINING.value: + return ( + ITrainingResult, + MAX_POLLS, + VIDEO_POLLING_DELAY, + f"Training timeout after {MAX_POLLS} polls" + ) case _: raise ValueError(f"Unsupported task type for polling: {task_type_val}") diff --git a/runware/types.py b/runware/types.py index 76a62fa..13d489a 100644 --- a/runware/types.py +++ b/runware/types.py @@ -50,6 +50,7 @@ class ETaskType(Enum): GET_RESPONSE = "getResponse" GET_TASK_DETAILS = "getTaskDetails" IMAGE_VECTORIZE = "vectorize" + TRAINING = "training" class EPreProcessorGroup(Enum): @@ -1239,6 +1240,32 @@ def request_key(self) -> str: return "inputs" +@dataclass +class ITrainingImportModel(SerializableMixin): + air: str + name: str + uniqueIdentifier: str + version: str + private: bool + heroImageURL: Optional[str] = None + shortDescription: Optional[str] = None + architecture: Optional[str] = None + category: Optional[str] = None + + @property + def request_key(self) -> str: + return "importModel" + + +@dataclass +class ITrainingInputs(SerializableMixin): + dataset: Optional[Union[str, File]] = None + + @property + def request_key(self) -> str: + return "inputs" + + @dataclass class IImageInference: model: Union[int, str] @@ -1944,6 +1971,33 @@ def __post_init__(self): self.inputs = I3dInputs(**self.inputs) +@dataclass +class ITraining: + model: str + importModel: Union[ITrainingImportModel, Dict[str, Any]] + inputs: Union[ITrainingInputs, Dict[str, Any]] + taskUUID: Optional[str] = None + deliveryMethod: str = "async" + includeCost: Optional[bool] = None + numberResults: Optional[int] = 1 + webhookURL: Optional[str] = None + + def __post_init__(self): + if self.importModel is not None and isinstance(self.importModel, dict): + self.importModel = ITrainingImportModel(**self.importModel) + if self.inputs is not None and isinstance(self.inputs, dict): + self.inputs = ITrainingInputs(**self.inputs) + + +@dataclass +class ITrainingResult: + taskType: str + taskUUID: str + status: Optional[str] = None + cost: Optional[float] = None + outputs: Optional[Dict[str, Any]] = None + + @dataclass class IAudioInputs(SerializableMixin): audio: Optional[str] = None From 2569b5e6d64eb05d15cb75ba36eb11989ef2b86c Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Tue, 5 May 2026 17:06:54 +0530 Subject: [PATCH 2/5] Made ITraining async only --- README.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ runware/types.py | 2 ++ 2 files changed, 64 insertions(+) diff --git a/README.md b/README.md index 02f7555..2652702 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,68 @@ async def main() -> None: - Use `getResponse(taskUUID)` to retrieve results at any time - `deliveryMethod="sync"` waits for complete results (may timeout for long-running tasks) +### Training (Async Only) + +Training is a long-running task type. +- `deliveryMethod="sync"` is not supported for training and raises a `ValueError` +- Use `deliveryMethod="async"` and retrieve final results with `getResponse(taskUUID)` + +```python +import uuid +from runware import Runware, ITraining, ITrainingImportModel, ITrainingInputs + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + request_training = ITraining( + taskUUID=str(uuid.uuid4()), + taskType="training", + model="runware:illustrative@training", + deliveryMethod="async", + importModel=ITrainingImportModel( + air="runware:illustrative@0", + name="Runware Illustrative Training Model", + uniqueIdentifier="exacltyai_illustrative_model_1", + version="1.0", + private=False, + heroImageURL="https://example.com/hero-image.png", + shortDescription="First training model", + ), + inputs=ITrainingInputs( + dataset="example/pictures.zip" + ), + ) + + training_task = await runware.training(requestTraining=request_training) + results = await runware.getResponse(taskUUID=training_task.taskUUID) + print(results) +``` + +After training, you can run image inference with the trained model: + +```python +import uuid +from runware import Runware, IImageInference, IInputs, IInputReference + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + request_image = IImageInference( + taskUUID=str(uuid.uuid4()), + model="runware:illustrative@0", + positivePrompt="a horse", + numberResults=1, + width=1024, + height=1025, + deliveryMethod="sync", + ) + + image_task = await runware.imageInference(requestImage=request_image) + print(f"Image inference task submitted: {image_task.taskUUID}") +``` + ### Retrieving Original Task Request/Response To inspect the original request payload and response for a past task, use `getTaskDetails(taskUUID)`. diff --git a/runware/types.py b/runware/types.py index 13d489a..f8b518b 100644 --- a/runware/types.py +++ b/runware/types.py @@ -1983,6 +1983,8 @@ class ITraining: webhookURL: Optional[str] = None def __post_init__(self): + if self.deliveryMethod == "sync": + raise ValueError("ITraining is a long-running task. Please use 'async' delivery method.") if self.importModel is not None and isinstance(self.importModel, dict): self.importModel = ITrainingImportModel(**self.importModel) if self.inputs is not None and isinstance(self.inputs, dict): From b93641e4a25d70fe98cc0511c975b11ad964a030 Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Tue, 5 May 2026 17:41:43 +0530 Subject: [PATCH 3/5] Fixed Readme.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 2652702..2c4b907 100644 --- a/README.md +++ b/README.md @@ -257,7 +257,6 @@ async def main() -> None: request_training = ITraining( taskUUID=str(uuid.uuid4()), - taskType="training", model="runware:illustrative@training", deliveryMethod="async", importModel=ITrainingImportModel( From 1002cdf19d79b3aa83122631f40fed9d7f58eb65 Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Tue, 5 May 2026 18:50:01 +0530 Subject: [PATCH 4/5] Fixed Readme.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2c4b907..9688470 100644 --- a/README.md +++ b/README.md @@ -262,7 +262,7 @@ async def main() -> None: importModel=ITrainingImportModel( air="runware:illustrative@0", name="Runware Illustrative Training Model", - uniqueIdentifier="exacltyai_illustrative_model_1", + uniqueIdentifier="runware_illustrative_model_1", version="1.0", private=False, heroImageURL="https://example.com/hero-image.png", From 946871651e88b4bd97d875fa61d8c6a4311f4f96 Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Wed, 6 May 2026 12:10:11 +0530 Subject: [PATCH 5/5] Removed numberResults --- runware/base.py | 4 +--- runware/types.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/runware/base.py b/runware/base.py index fd460eb..d4f2f76 100644 --- a/runware/base.py +++ b/runware/base.py @@ -2553,7 +2553,7 @@ async def _requestTraining(self, requestTraining: ITraining) -> Union[List[ITrai return await self._handleInitialTrainingResponse( request_object=request_object, task_uuid=requestTraining.taskUUID, - number_results=requestTraining.numberResults or 1, + number_results=1, delivery_method=requestTraining.deliveryMethod, webhook_url=requestTraining.webhookURL, debug_key="training-initial", @@ -2567,8 +2567,6 @@ def _buildTrainingRequest(self, requestTraining: ITraining) -> Dict[str, Any]: "deliveryMethod": requestTraining.deliveryMethod, } - if requestTraining.numberResults is not None: - request_object["numberResults"] = requestTraining.numberResults if requestTraining.includeCost is not None: request_object["includeCost"] = requestTraining.includeCost if requestTraining.webhookURL is not None: diff --git a/runware/types.py b/runware/types.py index f8b518b..17c93a6 100644 --- a/runware/types.py +++ b/runware/types.py @@ -1979,7 +1979,6 @@ class ITraining: taskUUID: Optional[str] = None deliveryMethod: str = "async" includeCost: Optional[bool] = None - numberResults: Optional[int] = 1 webhookURL: Optional[str] = None def __post_init__(self):