From ea26be188278428601914075ea1142c0de6f1438 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Tue, 31 Mar 2026 22:58:27 +0530 Subject: [PATCH 01/10] [FIX] Unified retry for LLM and embedding providers litellm's retry only works for SDK-based providers (OpenAI/Azure). httpx-based providers (Anthropic, Vertex, Bedrock, Mistral) and ALL embedding calls silently ignore max_retries. This adds self-managed retry with exponential backoff at the SDK layer, disabling litellm's own retry entirely for consistency. Co-Authored-By: Claude Opus 4.6 (1M context) --- unstract/sdk1/src/unstract/sdk1/embedding.py | 81 ++++++++++- unstract/sdk1/src/unstract/sdk1/llm.py | 137 ++++++++++++++++-- .../src/unstract/sdk1/utils/retry_utils.py | 39 +++++ 3 files changed, 234 insertions(+), 23 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index e54a093393..ccf97a276e 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio +import logging import os +import time from typing import TYPE_CHECKING import litellm @@ -14,10 +17,13 @@ from unstract.sdk1.exceptions import SdkError, parse_litellm_err from unstract.sdk1.platform import PlatformHelper from unstract.sdk1.utils.callback_manager import CallbackManager +from unstract.sdk1.utils.retry_utils import calculate_delay, is_retryable_litellm_error if TYPE_CHECKING: from unstract.sdk1.tool.base import BaseTool +logger = logging.getLogger(__name__) + litellm.drop_params = True @@ -110,13 +116,70 @@ def _get_adapter_info(self) -> str: return f"{self._adapter_name} ({name})" return name + def _embedding_with_retry( # noqa: ANN401 + self, model: str, input_data: list[str], **kwargs: object + ) -> object: + """Call litellm.embedding with retry on transient errors. + + litellm has no wrapper-level retry for embeddings (unlike completion). + Only SDK-based providers (OpenAI/Azure) honour max_retries natively; + httpx-based providers silently ignore it. This method provides uniform + retry across all providers. + """ + max_retries = kwargs.pop("max_retries", None) or 0 + # Prevent SDK-level retry so we don't double-retry for OpenAI/Azure + kwargs["max_retries"] = 0 + + for attempt in range(max_retries + 1): + try: + return litellm.embedding(model=model, input=input_data, **kwargs) + except Exception as e: + if attempt < max_retries and is_retryable_litellm_error(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + logger.warning( + "Embedding retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + self._get_adapter_info(), + e, + delay, + ) + time.sleep(delay) + else: + raise + + async def _aembedding_with_retry( # noqa: ANN401 + self, model: str, input_data: list[str], **kwargs: object + ) -> object: + """Async version of _embedding_with_retry.""" + max_retries = kwargs.pop("max_retries", None) or 0 + kwargs["max_retries"] = 0 + + for attempt in range(max_retries + 1): + try: + return await litellm.aembedding(model=model, input=input_data, **kwargs) + except Exception as e: + if attempt < max_retries and is_retryable_litellm_error(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + logger.warning( + "Embedding retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + self._get_adapter_info(), + e, + delay, + ) + await asyncio.sleep(delay) + else: + raise + def get_embedding(self, text: str) -> list[float]: """Return embedding vector for query string.""" try: kwargs = self.kwargs.copy() model = kwargs.pop("model") - resp = litellm.embedding(model=model, input=[text], **kwargs) + resp = self._embedding_with_retry(model=model, input_data=[text], **kwargs) return resp["data"][0]["embedding"] except Exception as e: @@ -128,7 +191,7 @@ def get_embeddings(self, texts: list[str]) -> list[list[float]]: kwargs = self.kwargs.copy() model = kwargs.pop("model") - resp = litellm.embedding(model=model, input=texts, **kwargs) + resp = self._embedding_with_retry(model=model, input_data=texts, **kwargs) return [data["embedding"] for data in resp["data"]] except Exception as e: @@ -140,12 +203,13 @@ async def get_aembedding(self, text: str) -> list[float]: kwargs = self.kwargs.copy() model = kwargs.pop("model") - resp = await litellm.aembedding(model=model, input=[text], **kwargs) + resp = await self._aembedding_with_retry( + model=model, input_data=[text], **kwargs + ) return resp["data"][0]["embedding"] except Exception as e: - provider_name = f"{self.adapter.get_name()}" - raise parse_litellm_err(e, provider_name) from e + raise parse_litellm_err(e, self._get_adapter_info()) from e async def get_aembeddings(self, texts: list[str]) -> list[list[float]]: """Return async embedding vectors for list of query strings.""" @@ -153,12 +217,13 @@ async def get_aembeddings(self, texts: list[str]) -> list[list[float]]: kwargs = self.kwargs.copy() model = kwargs.pop("model") - resp = await litellm.aembedding(model=model, input=texts, **kwargs) + resp = await self._aembedding_with_retry( + model=model, input_data=texts, **kwargs + ) return [data["embedding"] for data in resp["data"]] except Exception as e: - provider_name = f"{self.adapter.get_name()}" - raise parse_litellm_err(e, provider_name) from e + raise parse_litellm_err(e, self._get_adapter_info()) from e def test_connection(self) -> bool: """Test connection to the embedding provider.""" diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index 8ff29a89d5..37e36da7aa 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -1,6 +1,8 @@ +import asyncio import logging import os import re +import time from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass, field from enum import Enum @@ -24,6 +26,7 @@ TokenCounterCompat, capture_metrics, ) +from unstract.sdk1.utils.retry_utils import calculate_delay, is_retryable_litellm_error logger = logging.getLogger(__name__) @@ -285,9 +288,8 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]: # if hasattr(self, "thinking_dict") and self.thinking_dict is not None: # completion_kwargs["temperature"] = 1 - response: dict[str, object] = litellm.completion( - messages=messages, - **completion_kwargs, + response: dict[str, object] = self._completion_with_retry( + messages, completion_kwargs ) response_text = response["choices"][0]["message"]["content"] @@ -374,14 +376,7 @@ def stream_complete( completion_kwargs.pop("cost_model", None) has_yielded_content = False - for chunk in litellm.completion( - messages=messages, - stream=True, - stream_options={ - "include_usage": True, - }, - **completion_kwargs, - ): + for chunk in self._stream_completion_with_retry(messages, completion_kwargs): if chunk.get("usage"): self._record_usage( self._cost_model or self.kwargs["model"], @@ -437,10 +432,7 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]: completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) - response = await litellm.acompletion( - messages=messages, - **completion_kwargs, - ) + response = await self._acompletion_with_retry(messages, completion_kwargs) response_text = response["choices"][0]["message"]["content"] finish_reason = response["choices"][0].get("finish_reason") @@ -532,6 +524,121 @@ def get_metrics(self) -> dict[str, object]: def get_usage_reason(self) -> object: return self.platform_kwargs.get("llm_usage_reason") + @staticmethod + def _disable_litellm_retry(kwargs: dict[str, Any]) -> int: + """Extract max_retries from kwargs and disable litellm's own retry. + + Returns the user-configured max_retries value. + """ + max_retries = kwargs.pop("max_retries", None) or 0 + # Prevent SDK-level retry (OpenAI/Azure native SDK) + kwargs["max_retries"] = 0 + # Prevent litellm wrapper retry (completion_with_retries) + kwargs["num_retries"] = 0 + return max_retries + + def _completion_with_retry( + self, + messages: list[dict[str, str]], + completion_kwargs: dict[str, Any], + ) -> dict[str, Any]: + """Call litellm.completion with retry on transient errors. + + litellm's wrapper retry (completion_with_retries) only activates via + num_retries and only works for SDK-based providers. This method + provides uniform retry across all providers by managing retries + ourselves and disabling litellm's own retry entirely. + """ + max_retries = self._disable_litellm_retry(completion_kwargs) + + for attempt in range(max_retries + 1): + try: + return litellm.completion(messages=messages, **completion_kwargs) + except Exception as e: + if attempt < max_retries and is_retryable_litellm_error(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + logger.warning( + "LLM retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + self._get_adapter_info(), + e, + delay, + ) + time.sleep(delay) + else: + raise + + def _stream_completion_with_retry( + self, + messages: list[dict[str, str]], + completion_kwargs: dict[str, Any], + ) -> Generator[dict[str, Any], None, None]: + """Yield raw chunks from litellm.completion(stream=True) with retry. + + Only retries if the error occurs before any chunks have been yielded. + Once content has been yielded to the caller, a mid-stream failure + is raised immediately (partial data can't be un-yielded). + """ + max_retries = self._disable_litellm_retry(completion_kwargs) + + for attempt in range(max_retries + 1): + has_yielded = False + try: + for chunk in litellm.completion( + messages=messages, + stream=True, + stream_options={"include_usage": True}, + **completion_kwargs, + ): + has_yielded = True + yield chunk + return + except Exception as e: + if ( + not has_yielded + and attempt < max_retries + and is_retryable_litellm_error(e) + ): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + logger.warning( + "LLM stream retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + self._get_adapter_info(), + e, + delay, + ) + time.sleep(delay) + else: + raise + + async def _acompletion_with_retry( + self, + messages: list[dict[str, str]], + completion_kwargs: dict[str, Any], + ) -> dict[str, Any]: + """Async version of _completion_with_retry.""" + max_retries = self._disable_litellm_retry(completion_kwargs) + + for attempt in range(max_retries + 1): + try: + return await litellm.acompletion(messages=messages, **completion_kwargs) + except Exception as e: + if attempt < max_retries and is_retryable_litellm_error(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + logger.warning( + "LLM async retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + self._get_adapter_info(), + e, + delay, + ) + await asyncio.sleep(delay) + else: + raise + def _record_usage( self, model: str, diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 9d1a1e4ec1..05319cf2a3 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -13,6 +13,45 @@ logger = logging.getLogger(__name__) +# HTTP status codes that indicate transient server-side failures worth retrying. +RETRYABLE_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504}) + +# Exception class names (from litellm, openai, httpx) that indicate transient +# connection/timeout failures. Resolved via duck-typing to avoid importing +# litellm in this utility module. +_RETRYABLE_ERROR_NAMES = frozenset( + { + "APIConnectionError", + "Timeout", + "ConnectTimeout", + "ReadTimeout", + } +) + + +def is_retryable_litellm_error(error: Exception) -> bool: + """Check if a litellm/provider API error should trigger a retry. + + Uses duck-typing (status_code attribute, class name) so this module + doesn't need to import litellm or openai directly. + """ + # Python built-in connection / timeout base classes + if isinstance(error, ConnectionError | TimeoutError): + return True + + # litellm/openai/httpx exception types that don't inherit from the + # built-ins above but still represent transient network failures. + if type(error).__name__ in _RETRYABLE_ERROR_NAMES: + return True + + # Status-code check covers litellm.RateLimitError (429), + # InternalServerError (500), ServiceUnavailableError (503), etc. + status_code = getattr(error, "status_code", None) + if status_code is not None and status_code in RETRYABLE_STATUS_CODES: + return True + + return False + def is_retryable_error(error: Exception) -> bool: """Check if an error is retryable. From cfb22ab32c6379495e03bf7c33cfd5722e5cc04a Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Tue, 31 Mar 2026 23:04:52 +0530 Subject: [PATCH 02/10] [REFACTOR] DRY retry logic into reusable call_with_retry utilities Move retry loops out of LLM/Embedding classes into generic call_with_retry, acall_with_retry, and iter_with_retry functions in retry_utils.py. Both classes now call these directly instead of maintaining their own retry helper methods. Co-Authored-By: Claude Opus 4.6 (1M context) --- unstract/sdk1/src/unstract/sdk1/embedding.py | 101 +++++-------- unstract/sdk1/src/unstract/sdk1/llm.py | 140 ++++-------------- .../src/unstract/sdk1/utils/retry_utils.py | 111 +++++++++++++- 3 files changed, 175 insertions(+), 177 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index ccf97a276e..967a67ee6a 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -1,9 +1,7 @@ from __future__ import annotations -import asyncio import logging import os -import time from typing import TYPE_CHECKING import litellm @@ -17,7 +15,11 @@ from unstract.sdk1.exceptions import SdkError, parse_litellm_err from unstract.sdk1.platform import PlatformHelper from unstract.sdk1.utils.callback_manager import CallbackManager -from unstract.sdk1.utils.retry_utils import calculate_delay, is_retryable_litellm_error +from unstract.sdk1.utils.retry_utils import ( + acall_with_retry, + call_with_retry, + is_retryable_litellm_error, +) if TYPE_CHECKING: from unstract.sdk1.tool.base import BaseTool @@ -116,71 +118,25 @@ def _get_adapter_info(self) -> str: return f"{self._adapter_name} ({name})" return name - def _embedding_with_retry( # noqa: ANN401 - self, model: str, input_data: list[str], **kwargs: object - ) -> object: - """Call litellm.embedding with retry on transient errors. - - litellm has no wrapper-level retry for embeddings (unlike completion). - Only SDK-based providers (OpenAI/Azure) honour max_retries natively; - httpx-based providers silently ignore it. This method provides uniform - retry across all providers. - """ + def _pop_retry_params(self, kwargs: dict[str, object]) -> int: + """Extract max_retries and disable litellm's SDK-level retry.""" max_retries = kwargs.pop("max_retries", None) or 0 - # Prevent SDK-level retry so we don't double-retry for OpenAI/Azure kwargs["max_retries"] = 0 - - for attempt in range(max_retries + 1): - try: - return litellm.embedding(model=model, input=input_data, **kwargs) - except Exception as e: - if attempt < max_retries and is_retryable_litellm_error(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - logger.warning( - "Embedding retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - self._get_adapter_info(), - e, - delay, - ) - time.sleep(delay) - else: - raise - - async def _aembedding_with_retry( # noqa: ANN401 - self, model: str, input_data: list[str], **kwargs: object - ) -> object: - """Async version of _embedding_with_retry.""" - max_retries = kwargs.pop("max_retries", None) or 0 - kwargs["max_retries"] = 0 - - for attempt in range(max_retries + 1): - try: - return await litellm.aembedding(model=model, input=input_data, **kwargs) - except Exception as e: - if attempt < max_retries and is_retryable_litellm_error(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - logger.warning( - "Embedding retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - self._get_adapter_info(), - e, - delay, - ) - await asyncio.sleep(delay) - else: - raise + return max_retries def get_embedding(self, text: str) -> list[float]: """Return embedding vector for query string.""" try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = self._embedding_with_retry(model=model, input_data=[text], **kwargs) - + resp = call_with_retry( + lambda: litellm.embedding(model=model, input=[text], **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) return resp["data"][0]["embedding"] except Exception as e: raise parse_litellm_err(e, self._get_adapter_info()) from e @@ -190,9 +146,14 @@ def get_embeddings(self, texts: list[str]) -> list[list[float]]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = self._embedding_with_retry(model=model, input_data=texts, **kwargs) - + resp = call_with_retry( + lambda: litellm.embedding(model=model, input=texts, **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) return [data["embedding"] for data in resp["data"]] except Exception as e: raise parse_litellm_err(e, self._get_adapter_info()) from e @@ -202,11 +163,14 @@ async def get_aembedding(self, text: str) -> list[float]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = await self._aembedding_with_retry( - model=model, input_data=[text], **kwargs + resp = await acall_with_retry( + lambda: litellm.aembedding(model=model, input=[text], **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ) - return resp["data"][0]["embedding"] except Exception as e: raise parse_litellm_err(e, self._get_adapter_info()) from e @@ -216,11 +180,14 @@ async def get_aembeddings(self, texts: list[str]) -> list[list[float]]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = await self._aembedding_with_retry( - model=model, input_data=texts, **kwargs + resp = await acall_with_retry( + lambda: litellm.aembedding(model=model, input=texts, **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ) - return [data["embedding"] for data in resp["data"]] except Exception as e: raise parse_litellm_err(e, self._get_adapter_info()) from e diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index 37e36da7aa..a4c9055de1 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -1,8 +1,6 @@ -import asyncio import logging import os import re -import time from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass, field from enum import Enum @@ -26,7 +24,12 @@ TokenCounterCompat, capture_metrics, ) -from unstract.sdk1.utils.retry_utils import calculate_delay, is_retryable_litellm_error +from unstract.sdk1.utils.retry_utils import ( + acall_with_retry, + call_with_retry, + is_retryable_litellm_error, + iter_with_retry, +) logger = logging.getLogger(__name__) @@ -288,8 +291,12 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]: # if hasattr(self, "thinking_dict") and self.thinking_dict is not None: # completion_kwargs["temperature"] = 1 - response: dict[str, object] = self._completion_with_retry( - messages, completion_kwargs + max_retries = self._disable_litellm_retry(completion_kwargs) + response: dict[str, object] = call_with_retry( + lambda: litellm.completion(messages=messages, **completion_kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ) response_text = response["choices"][0]["message"]["content"] @@ -375,8 +382,19 @@ def stream_complete( completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) + max_retries = self._disable_litellm_retry(completion_kwargs) has_yielded_content = False - for chunk in self._stream_completion_with_retry(messages, completion_kwargs): + for chunk in iter_with_retry( + lambda: litellm.completion( + messages=messages, + stream=True, + stream_options={"include_usage": True}, + **completion_kwargs, + ), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ): if chunk.get("usage"): self._record_usage( self._cost_model or self.kwargs["model"], @@ -432,7 +450,13 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]: completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) - response = await self._acompletion_with_retry(messages, completion_kwargs) + max_retries = self._disable_litellm_retry(completion_kwargs) + response = await acall_with_retry( + lambda: litellm.acompletion(messages=messages, **completion_kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) response_text = response["choices"][0]["message"]["content"] finish_reason = response["choices"][0].get("finish_reason") @@ -537,108 +561,6 @@ def _disable_litellm_retry(kwargs: dict[str, Any]) -> int: kwargs["num_retries"] = 0 return max_retries - def _completion_with_retry( - self, - messages: list[dict[str, str]], - completion_kwargs: dict[str, Any], - ) -> dict[str, Any]: - """Call litellm.completion with retry on transient errors. - - litellm's wrapper retry (completion_with_retries) only activates via - num_retries and only works for SDK-based providers. This method - provides uniform retry across all providers by managing retries - ourselves and disabling litellm's own retry entirely. - """ - max_retries = self._disable_litellm_retry(completion_kwargs) - - for attempt in range(max_retries + 1): - try: - return litellm.completion(messages=messages, **completion_kwargs) - except Exception as e: - if attempt < max_retries and is_retryable_litellm_error(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - logger.warning( - "LLM retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - self._get_adapter_info(), - e, - delay, - ) - time.sleep(delay) - else: - raise - - def _stream_completion_with_retry( - self, - messages: list[dict[str, str]], - completion_kwargs: dict[str, Any], - ) -> Generator[dict[str, Any], None, None]: - """Yield raw chunks from litellm.completion(stream=True) with retry. - - Only retries if the error occurs before any chunks have been yielded. - Once content has been yielded to the caller, a mid-stream failure - is raised immediately (partial data can't be un-yielded). - """ - max_retries = self._disable_litellm_retry(completion_kwargs) - - for attempt in range(max_retries + 1): - has_yielded = False - try: - for chunk in litellm.completion( - messages=messages, - stream=True, - stream_options={"include_usage": True}, - **completion_kwargs, - ): - has_yielded = True - yield chunk - return - except Exception as e: - if ( - not has_yielded - and attempt < max_retries - and is_retryable_litellm_error(e) - ): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - logger.warning( - "LLM stream retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - self._get_adapter_info(), - e, - delay, - ) - time.sleep(delay) - else: - raise - - async def _acompletion_with_retry( - self, - messages: list[dict[str, str]], - completion_kwargs: dict[str, Any], - ) -> dict[str, Any]: - """Async version of _completion_with_retry.""" - max_retries = self._disable_litellm_retry(completion_kwargs) - - for attempt in range(max_retries + 1): - try: - return await litellm.acompletion(messages=messages, **completion_kwargs) - except Exception as e: - if attempt < max_retries and is_retryable_litellm_error(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - logger.warning( - "LLM async retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - self._get_adapter_info(), - e, - delay, - ) - await asyncio.sleep(delay) - else: - raise - def _record_usage( self, model: str, diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 05319cf2a3..37322cb9d5 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -1,11 +1,12 @@ """Generic retry utilities with custom exponential backoff implementation.""" +import asyncio import errno import logging import os import random import time -from collections.abc import Callable +from collections.abc import Callable, Generator from functools import wraps from typing import Any @@ -53,6 +54,114 @@ def is_retryable_litellm_error(error: Exception) -> bool: return False +# ── Generic retry wrappers ────────────────────────────────────────────────── +# Unlike the decorator-based retry_with_exponential_backoff (env-var configured, +# sync-only), these accept max_retries at call time and support async + generators. + + +def call_with_retry( + fn: Callable[[], object], + *, + max_retries: int, + retry_predicate: Callable[[Exception], bool], + description: str = "", + logger_instance: logging.Logger | None = None, +) -> object: + """Execute fn() with retry on transient errors. + + Args: + fn: Zero-arg callable to invoke (use a lambda to bind args). + max_retries: Maximum retry attempts (0 = no retry). + retry_predicate: Returns True if the exception should trigger a retry. + description: Label for log messages (e.g. adapter name). + logger_instance: Logger; defaults to module logger. + """ + log = logger_instance or logger + for attempt in range(max_retries + 1): + try: + return fn() + except Exception as e: + if attempt < max_retries and retry_predicate(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + log.warning( + "Retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + description, + e, + delay, + ) + time.sleep(delay) + else: + raise + + +async def acall_with_retry( + fn: Callable[[], object], + *, + max_retries: int, + retry_predicate: Callable[[Exception], bool], + description: str = "", + logger_instance: logging.Logger | None = None, +) -> object: + """Async version of call_with_retry — awaits fn().""" + log = logger_instance or logger + for attempt in range(max_retries + 1): + try: + return await fn() + except Exception as e: + if attempt < max_retries and retry_predicate(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + log.warning( + "Retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + description, + e, + delay, + ) + await asyncio.sleep(delay) + else: + raise + + +def iter_with_retry( + fn: Callable[[], object], + *, + max_retries: int, + retry_predicate: Callable[[Exception], bool], + description: str = "", + logger_instance: logging.Logger | None = None, +) -> Generator: + """Yield from fn() with retry. Only retries before the first yield. + + Once items have been yielded to the caller a mid-iteration failure is + raised immediately — partial output can't be un-yielded. + """ + log = logger_instance or logger + for attempt in range(max_retries + 1): + has_yielded = False + try: + for item in fn(): + has_yielded = True + yield item + return + except Exception as e: + if not has_yielded and attempt < max_retries and retry_predicate(e): + delay = calculate_delay(attempt, 1.0, 2.0, 60.0) + log.warning( + "Retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + description, + e, + delay, + ) + time.sleep(delay) + else: + raise + + def is_retryable_error(error: Exception) -> bool: """Check if an error is retryable. From ffbb5ffc20ed82d6b79d8be7fd0621ab876eeb14 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 1 Apr 2026 16:07:51 +0530 Subject: [PATCH 03/10] [FIX] Consolidate retry logic, expose max_retries for all adapters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract _get_retry_delay() shared helper to eliminate duplicated retry decision logic across call_with_retry, acall_with_retry, iter_with_retry, and retry_with_exponential_backoff - Add num_retries=0 to embedding._pop_retry_params() to fully disable litellm's internal retry for embedding calls - Expose max_retries in UI JSON schemas for embedding adapters (OpenAI, Azure, VertexAI, Ollama) and Ollama LLM — previously the field existed in Pydantic models but wasn't shown to users, silently defaulting to 0 retries - Add debug logging to LLM and Embedding retry parameter extraction - Clarify docstrings distinguishing is_retryable_litellm_error() from is_retryable_error() (different exception hierarchies) - Remove stale noqa: C901 from simplified retry_with_exponential_backoff Co-Authored-By: Claude Opus 4.6 (1M context) --- .../adapters/embedding1/static/azure.json | 8 + .../adapters/embedding1/static/ollama.json | 8 + .../adapters/embedding1/static/openai.json | 8 + .../adapters/embedding1/static/vertexai.json | 8 + .../sdk1/adapters/llm1/static/ollama.json | 8 + unstract/sdk1/src/unstract/sdk1/embedding.py | 7 + unstract/sdk1/src/unstract/sdk1/llm.py | 5 + .../src/unstract/sdk1/utils/retry_utils.py | 197 +++++++++--------- 8 files changed, 146 insertions(+), 103 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json index 7363a41aab..33c93a8277 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json @@ -61,6 +61,14 @@ "title": "Embedding Batch Size", "default": 5 }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." + }, "timeout": { "type": "number", "minimum": 0, diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json index fe292f683c..63cf865976 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json @@ -31,6 +31,14 @@ "multipleOf": 1, "title": "Embed Batch Size", "default": 10 + }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." } } } diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json index 9be724e41f..aec2fa3648 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json @@ -44,6 +44,14 @@ "title": "Embed Batch Size", "default": 10 }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." + }, "timeout": { "type": "number", "minimum": 0, diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json index 1534fcf93d..14563fec1f 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json @@ -57,6 +57,14 @@ "retrieval" ], "default": "default" + }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." } } } diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json index 3800814c77..a321baaa63 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json @@ -48,6 +48,14 @@ "default": 3900, "description": "The maximum number of context tokens for the model." }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." + }, "request_timeout": { "type": "number", "minimum": 0, diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index 967a67ee6a..0d13221617 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -122,6 +122,13 @@ def _pop_retry_params(self, kwargs: dict[str, object]) -> int: """Extract max_retries and disable litellm's SDK-level retry.""" max_retries = kwargs.pop("max_retries", None) or 0 kwargs["max_retries"] = 0 + kwargs["num_retries"] = 0 + logger.debug( + "Embedding: extracted max_retries=%d, " + "disabled litellm retry (max_retries=0, num_retries=0) for %s", + max_retries, + self._get_adapter_info(), + ) return max_retries def get_embedding(self, text: str) -> list[float]: diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index a4c9055de1..2f5bbc6242 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -559,6 +559,11 @@ def _disable_litellm_retry(kwargs: dict[str, Any]) -> int: kwargs["max_retries"] = 0 # Prevent litellm wrapper retry (completion_with_retries) kwargs["num_retries"] = 0 + logger.debug( + "LLM: extracted max_retries=%d, " + "disabled litellm retry (max_retries=0, num_retries=0)", + max_retries, + ) return max_retries def _record_usage( diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 37322cb9d5..cccbcb67c2 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -33,8 +33,11 @@ def is_retryable_litellm_error(error: Exception) -> bool: """Check if a litellm/provider API error should trigger a retry. - Uses duck-typing (status_code attribute, class name) so this module - doesn't need to import litellm or openai directly. + Distinct from is_retryable_error() which handles requests-library exceptions + (requests.ConnectionError, requests.HTTPError.response.status_code, OSError). + litellm/openai/httpx have a separate exception hierarchy: status_code lives + on the exception itself, and class names like APIConnectionError don't inherit + from the requests types. Uses duck-typing to avoid importing litellm directly. """ # Python built-in connection / timeout base classes if isinstance(error, ConnectionError | TimeoutError): @@ -54,9 +57,57 @@ def is_retryable_litellm_error(error: Exception) -> bool: return False +# ── Shared retry decision ─────────────────────────────────────────────────── + + +def _get_retry_delay( + error: Exception, + attempt: int, + max_retries: int, + retry_predicate: Callable[[Exception], bool] | None, + description: str, + logger_instance: logging.Logger, + base_delay: float = 1.0, + multiplier: float = 2.0, + max_delay: float = 60.0, + jitter: bool = True, +) -> float | None: + """Decide whether to retry and compute the backoff delay. + + Returns delay in seconds if the error is retryable, None otherwise. + The caller is responsible for sleeping (sync or async) and re-raising + when None is returned. + """ + should_retry = retry_predicate(error) if retry_predicate is not None else True + + logger_instance.debug( + "Retry decision: attempt=%d/%d error=%s retryable=%s description=%s", + attempt + 1, + max_retries + 1, + type(error).__name__, + should_retry, + description, + ) + + if not should_retry or attempt >= max_retries: + return None + + delay = calculate_delay(attempt, base_delay, multiplier, max_delay, jitter) + logger_instance.warning( + "Retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + description, + error, + delay, + ) + return delay + + # ── Generic retry wrappers ────────────────────────────────────────────────── # Unlike the decorator-based retry_with_exponential_backoff (env-var configured, # sync-only), these accept max_retries at call time and support async + generators. +# All delegate retry decisions to _get_retry_delay above. def call_with_retry( @@ -67,33 +118,18 @@ def call_with_retry( description: str = "", logger_instance: logging.Logger | None = None, ) -> object: - """Execute fn() with retry on transient errors. - - Args: - fn: Zero-arg callable to invoke (use a lambda to bind args). - max_retries: Maximum retry attempts (0 = no retry). - retry_predicate: Returns True if the exception should trigger a retry. - description: Label for log messages (e.g. adapter name). - logger_instance: Logger; defaults to module logger. - """ + """Execute fn() with retry on transient errors.""" log = logger_instance or logger for attempt in range(max_retries + 1): try: return fn() except Exception as e: - if attempt < max_retries and retry_predicate(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - log.warning( - "Retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - description, - e, - delay, - ) - time.sleep(delay) - else: + delay = _get_retry_delay( + e, attempt, max_retries, retry_predicate, description, log + ) + if delay is None: raise + time.sleep(delay) async def acall_with_retry( @@ -110,19 +146,12 @@ async def acall_with_retry( try: return await fn() except Exception as e: - if attempt < max_retries and retry_predicate(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - log.warning( - "Retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - description, - e, - delay, - ) - await asyncio.sleep(delay) - else: + delay = _get_retry_delay( + e, attempt, max_retries, retry_predicate, description, log + ) + if delay is None: raise + await asyncio.sleep(delay) def iter_with_retry( @@ -147,34 +176,23 @@ def iter_with_retry( yield item return except Exception as e: - if not has_yielded and attempt < max_retries and retry_predicate(e): - delay = calculate_delay(attempt, 1.0, 2.0, 60.0) - log.warning( - "Retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - description, - e, - delay, - ) - time.sleep(delay) - else: + if has_yielded: + raise + delay = _get_retry_delay( + e, attempt, max_retries, retry_predicate, description, log + ) + if delay is None: raise + time.sleep(delay) def is_retryable_error(error: Exception) -> bool: - """Check if an error is retryable. - - Handles: - - ConnectionError and Timeout from requests - - HTTPError with status codes 502, 503, 504 - - OSError with specific errno codes (ECONNREFUSED, ECONNRESET, etc.) - - Args: - error: The exception to check + """Check if a requests-library HTTP error should trigger a retry. - Returns: - True if the error should trigger a retry + For retrying internal service calls (platform-service, prompt-service) that + use the requests library. Distinct from is_retryable_litellm_error() which + handles litellm/openai/httpx exceptions with different class hierarchies + (e.g. error.status_code vs error.response.status_code). """ # Requests connection and timeout errors if isinstance(error, ConnectionError | Timeout): @@ -233,7 +251,7 @@ def calculate_delay( return min(delay, max_delay) -def retry_with_exponential_backoff( # noqa: C901 +def retry_with_exponential_backoff( max_retries: int, base_delay: float, multiplier: float, @@ -259,38 +277,33 @@ def retry_with_exponential_backoff( # noqa: C901 Decorator function """ - def decorator(func: Callable) -> Callable: # noqa: C901 + def decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: C901, ANN401 - last_exception = None - - for attempt in range(max_retries + 1): # +1 for initial attempt + def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + for attempt in range(max_retries + 1): try: - # Try to execute the function result = func(*args, **kwargs) - - # If successful and we had retried, log success if attempt > 0: logger_instance.info( "Successfully completed '%s' after %d retry attempt(s)", func.__name__, attempt, ) - return result - except exceptions as e: - last_exception = e - - # Check if the error should trigger a retry - # First check if it's in the allowed exception types (already caught) - # Then check using the predicate if provided - should_retry = True - if retry_predicate is not None: - should_retry = retry_predicate(e) - - # If not retryable or last attempt, raise the error - if not should_retry or attempt == max_retries: + delay = _get_retry_delay( + e, + attempt, + max_retries, + retry_predicate, + prefix, + logger_instance, + base_delay, + multiplier, + 60.0, + jitter, + ) + if delay is None: if attempt > 0: logger_instance.exception( "Giving up '%s' after %d attempt(s) for %s", @@ -299,32 +312,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: C901, ANN401 prefix, ) raise - - # Calculate delay for next retry (capped at 60s) - delay = calculate_delay(attempt, base_delay, multiplier, 60.0, jitter) - - # Log retry attempt - logger_instance.warning( - "Retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - prefix, - e, - delay, - ) - - # Wait before retrying time.sleep(delay) - - except Exception as e: - # Exception not in the exceptions tuple - don't retry - last_exception = e + except Exception: raise - # This should never be reached, but just in case - if last_exception: - raise last_exception - return wrapper return decorator From f83404aeaaadb898ee4383130ac2eaa9238faa85 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 1 Apr 2026 18:44:24 +0530 Subject: [PATCH 04/10] [FIX] Set max_retries default to 3 for all embedding and Ollama LLM adapters Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/unstract/sdk1/adapters/embedding1/static/azure.json | 2 +- .../src/unstract/sdk1/adapters/embedding1/static/bedrock.json | 4 ++-- .../src/unstract/sdk1/adapters/embedding1/static/ollama.json | 2 +- .../src/unstract/sdk1/adapters/embedding1/static/openai.json | 2 +- .../unstract/sdk1/adapters/embedding1/static/vertexai.json | 2 +- .../sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json index 33c93a8277..f1ca8d2230 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json @@ -66,7 +66,7 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, + "default": 3, "description": "The maximum number of times to retry a request if it fails." }, "timeout": { diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json index b71028148c..81829a0f49 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json @@ -43,8 +43,8 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, - "description": "Maximum number of retries to attempt when a request fails." + "default": 3, + "description": "The maximum number of times to retry a request if it fails." }, "timeout": { "type": "number", diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json index 63cf865976..8dd9bfa1c3 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json @@ -37,7 +37,7 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, + "default": 3, "description": "The maximum number of times to retry a request if it fails." } } diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json index aec2fa3648..3ad21d3564 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json @@ -49,7 +49,7 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, + "default": 3, "description": "The maximum number of times to retry a request if it fails." }, "timeout": { diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json index 14563fec1f..6aa48e883f 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json @@ -63,7 +63,7 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, + "default": 3, "description": "The maximum number of times to retry a request if it fails." } } diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json index a321baaa63..3c8a4a5f16 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json @@ -53,7 +53,7 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, + "default": 3, "description": "The maximum number of times to retry a request if it fails." }, "request_timeout": { From 43fed188202b08dd9744fa8b523708b7bbfe39f3 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 1 Apr 2026 18:54:52 +0530 Subject: [PATCH 05/10] [FIX] Address greptile review: fix shadowed ConnectionError, use MRO check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix `requests.ConnectionError` shadowing Python's builtin `ConnectionError` in `is_retryable_litellm_error()` — rename import to `RequestsConnectionError` and use `builtins.ConnectionError` / `builtins.TimeoutError` explicitly - Use `__mro__`-based class name check instead of `type(error).__name__` to also catch subclasses of retryable error types - P1 (num_retries not zeroed) was already fixed in prior commit Co-Authored-By: Claude Opus 4.6 (1M context) --- .../sdk1/src/unstract/sdk1/utils/retry_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index cccbcb67c2..18000b5452 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -1,6 +1,7 @@ """Generic retry utilities with custom exponential backoff implementation.""" import asyncio +import builtins import errno import logging import os @@ -10,7 +11,8 @@ from functools import wraps from typing import Any -from requests.exceptions import ConnectionError, HTTPError, Timeout +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import HTTPError, Timeout logger = logging.getLogger(__name__) @@ -39,13 +41,14 @@ def is_retryable_litellm_error(error: Exception) -> bool: on the exception itself, and class names like APIConnectionError don't inherit from the requests types. Uses duck-typing to avoid importing litellm directly. """ - # Python built-in connection / timeout base classes - if isinstance(error, ConnectionError | TimeoutError): + # Python built-in connection / timeout base classes (not requests.ConnectionError) + if isinstance(error, builtins.ConnectionError | builtins.TimeoutError): return True # litellm/openai/httpx exception types that don't inherit from the # built-ins above but still represent transient network failures. - if type(error).__name__ in _RETRYABLE_ERROR_NAMES: + # Check MRO to also catch subclasses of these error types. + if any(cls.__name__ in _RETRYABLE_ERROR_NAMES for cls in type(error).__mro__): return True # Status-code check covers litellm.RateLimitError (429), @@ -195,7 +198,7 @@ def is_retryable_error(error: Exception) -> bool: (e.g. error.status_code vs error.response.status_code). """ # Requests connection and timeout errors - if isinstance(error, ConnectionError | Timeout): + if isinstance(error, RequestsConnectionError | Timeout): return True # HTTP errors with specific status codes @@ -332,7 +335,7 @@ def create_retry_decorator( Args: prefix: Environment variable prefix for configuration exceptions: Tuple of exception types to retry on. - Defaults to (ConnectionError, HTTPError, Timeout, OSError) + Defaults to (RequestsConnectionError, HTTPError, Timeout, OSError) retry_predicate: Optional callable to determine if exception should trigger retry. If only exceptions list provided, retry on those exceptions. If only predicate provided, use predicate (catch all exceptions). @@ -351,7 +354,7 @@ def create_retry_decorator( # Handle different combinations of exceptions and predicate if exceptions is None and retry_predicate is None: # Default case: use specific exceptions with is_retryable_error predicate - exceptions = (ConnectionError, HTTPError, Timeout, OSError) + exceptions = (RequestsConnectionError, HTTPError, Timeout, OSError) retry_predicate = is_retryable_error elif exceptions is None and retry_predicate is not None: # Only predicate provided: catch all exceptions and use predicate From c8054e87a986828506e3198d5a1f9f71dfc3b40e Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 1 Apr 2026 21:00:19 +0530 Subject: [PATCH 06/10] [FIX] Address CodeRabbit review: add APITimeoutError, validate max_retries - Add APITimeoutError to _RETRYABLE_ERROR_NAMES for explicit OpenAI SDK timeout coverage - Add _validate_max_retries() guard to call_with_retry, acall_with_retry, iter_with_retry to fail fast on negative values instead of silently returning None Co-Authored-By: Claude Opus 4.6 (1M context) --- unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 18000b5452..03392c13e4 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -25,6 +25,7 @@ _RETRYABLE_ERROR_NAMES = frozenset( { "APIConnectionError", + "APITimeoutError", "Timeout", "ConnectTimeout", "ReadTimeout", @@ -113,6 +114,11 @@ def _get_retry_delay( # All delegate retry decisions to _get_retry_delay above. +def _validate_max_retries(max_retries: int) -> None: + if max_retries < 0: + raise ValueError(f"max_retries must be >= 0, got {max_retries}") + + def call_with_retry( fn: Callable[[], object], *, @@ -122,6 +128,7 @@ def call_with_retry( logger_instance: logging.Logger | None = None, ) -> object: """Execute fn() with retry on transient errors.""" + _validate_max_retries(max_retries) log = logger_instance or logger for attempt in range(max_retries + 1): try: @@ -144,6 +151,7 @@ async def acall_with_retry( logger_instance: logging.Logger | None = None, ) -> object: """Async version of call_with_retry — awaits fn().""" + _validate_max_retries(max_retries) log = logger_instance or logger for attempt in range(max_retries + 1): try: @@ -170,6 +178,7 @@ def iter_with_retry( Once items have been yielded to the caller a mid-iteration failure is raised immediately — partial output can't be un-yielded. """ + _validate_max_retries(max_retries) log = logger_instance or logger for attempt in range(max_retries + 1): has_yielded = False From f186fb0854f21648253754b1eea0fd0877619059 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Fri, 17 Apr 2026 11:58:38 +0530 Subject: [PATCH 07/10] UN-3344 [FIX] Reduce cognitive complexity and remove useless except clause MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address SonarCloud findings on PR #1886: - S3776: Flatten retry_with_exponential_backoff.wrapper by moving the success logging + return out of the try block and using `continue` in the retry path, so the except branch only handles the give-up case. - S2737: Drop the `except Exception: raise` clause — it was a no-op that added complexity without changing behavior (non-matching exceptions propagate naturally). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/unstract/sdk1/utils/retry_utils.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 03392c13e4..bb5e9a9747 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -295,13 +295,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 for attempt in range(max_retries + 1): try: result = func(*args, **kwargs) - if attempt > 0: - logger_instance.info( - "Successfully completed '%s' after %d retry attempt(s)", - func.__name__, - attempt, - ) - return result except exceptions as e: delay = _get_retry_delay( e, @@ -315,18 +308,24 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 60.0, jitter, ) - if delay is None: - if attempt > 0: - logger_instance.exception( - "Giving up '%s' after %d attempt(s) for %s", - func.__name__, - attempt + 1, - prefix, - ) - raise - time.sleep(delay) - except Exception: + if delay is not None: + time.sleep(delay) + continue + if attempt > 0: + logger_instance.exception( + "Giving up '%s' after %d attempt(s) for %s", + func.__name__, + attempt + 1, + prefix, + ) raise + if attempt > 0: + logger_instance.info( + "Successfully completed '%s' after %d retry attempt(s)", + func.__name__, + attempt, + ) + return result return wrapper From d0c864f4fac3e82af7aaff2b995c9cd390006461 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Fri, 17 Apr 2026 13:44:41 +0530 Subject: [PATCH 08/10] UN-3344 [FIX] Extract retry loop to top-level helper to drop cognitive complexity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sonar still flagged retry_with_exponential_backoff at complexity 16 after the previous flatten. Nested def decorator / def wrapper counted against the outer function's score. Move the retry body to a module-level _invoke_with_retries helper so the decorator factory just delegates, bringing the outer function well under the 15 threshold. Behavior is unchanged — all paths (success, retry, give-up, non-retryable propagate) are preserved and covered by the existing SDK1 tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/unstract/sdk1/utils/retry_utils.py | 102 ++++++++++++------ 1 file changed, 68 insertions(+), 34 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index bb5e9a9747..b7f5b94929 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -263,6 +263,61 @@ def calculate_delay( return min(delay, max_delay) +def _invoke_with_retries( + func: Callable, + args: tuple, + kwargs: dict, + *, + max_retries: int, + base_delay: float, + multiplier: float, + jitter: bool, + exceptions: tuple[type[Exception], ...], + logger_instance: logging.Logger, + prefix: str, + retry_predicate: Callable[[Exception], bool] | None, +) -> Any: # noqa: ANN401 + """Execute func with exponential-backoff retries. + + See retry_with_exponential_backoff for parameter semantics. + """ + for attempt in range(max_retries + 1): + try: + result = func(*args, **kwargs) + except exceptions as e: + delay = _get_retry_delay( + e, + attempt, + max_retries, + retry_predicate, + prefix, + logger_instance, + base_delay, + multiplier, + 60.0, + jitter, + ) + if delay is not None: + time.sleep(delay) + continue + if attempt > 0: + logger_instance.exception( + "Giving up '%s' after %d attempt(s) for %s", + func.__name__, + attempt + 1, + prefix, + ) + raise + if attempt > 0: + logger_instance.info( + "Successfully completed '%s' after %d retry attempt(s)", + func.__name__, + attempt, + ) + return result + return None # unreachable: range(max_retries + 1) is non-empty + + def retry_with_exponential_backoff( max_retries: int, base_delay: float, @@ -292,40 +347,19 @@ def retry_with_exponential_backoff( def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 - for attempt in range(max_retries + 1): - try: - result = func(*args, **kwargs) - except exceptions as e: - delay = _get_retry_delay( - e, - attempt, - max_retries, - retry_predicate, - prefix, - logger_instance, - base_delay, - multiplier, - 60.0, - jitter, - ) - if delay is not None: - time.sleep(delay) - continue - if attempt > 0: - logger_instance.exception( - "Giving up '%s' after %d attempt(s) for %s", - func.__name__, - attempt + 1, - prefix, - ) - raise - if attempt > 0: - logger_instance.info( - "Successfully completed '%s' after %d retry attempt(s)", - func.__name__, - attempt, - ) - return result + return _invoke_with_retries( + func, + args, + kwargs, + max_retries=max_retries, + base_delay=base_delay, + multiplier=multiplier, + jitter=jitter, + exceptions=exceptions, + logger_instance=logger_instance, + prefix=prefix, + retry_predicate=retry_predicate, + ) return wrapper From 68927b5482fe51072bd0af48b21fd06ee3bb8bdd Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Fri, 17 Apr 2026 15:12:31 +0530 Subject: [PATCH 09/10] UN-3344 [FIX] Honor Retry-After, close stream gen on retry, share give-up log MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review comments on PR #1886: - #10 (resource leak): close the generator returned by fn() before retrying in iter_with_retry — otherwise streaming providers leak an in-flight HTTP socket until GC. - #12 (behavioral regression): when we zero out SDK/wrapper retries we also lose the OpenAI SDK's native Retry-After handling on 429/503. _get_retry_delay now checks error.response.headers["retry-after"] and uses that value ahead of exponential backoff. HTTP-date form is not parsed; those fall back to backoff. - #8 (observability gap): move the "Giving up ... after N attempt(s)" log into _get_retry_delay so all four retry helpers (call_with_retry, acall_with_retry, iter_with_retry, decorator) share the same exhaustion signal. Previously only the decorator path logged it. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/unstract/sdk1/utils/retry_utils.py | 59 ++++++++++++++++--- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index b7f5b94929..5aa13df729 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -64,6 +64,29 @@ def is_retryable_litellm_error(error: Exception) -> bool: # ── Shared retry decision ─────────────────────────────────────────────────── +def _extract_retry_after(error: Exception) -> float | None: + """Return the server-supplied Retry-After delay in seconds, if present. + + Honors the provider's explicit cool-down hint on 429/503 responses so our + backoff doesn't hammer the provider before its requested wait. Only the + integer/float seconds form is supported; RFC 7231 HTTP-date values fall + back to exponential backoff. + """ + response = getattr(error, "response", None) + if response is None: + return None + headers = getattr(response, "headers", None) + if headers is None: + return None + value = headers.get("retry-after") or headers.get("Retry-After") + if value is None: + return None + try: + return max(float(value), 0.0) + except (TypeError, ValueError): + return None + + def _get_retry_delay( error: Exception, attempt: int, @@ -94,9 +117,25 @@ def _get_retry_delay( ) if not should_retry or attempt >= max_retries: + # Shared exhaustion log — fires for every retry helper once retries + # were actually attempted (attempt > 0) and the error was retryable + # (i.e. we stopped because we ran out of attempts, not because the + # error type was non-retryable). + if attempt > 0 and should_retry: + logger_instance.exception( + "Giving up %s after %d attempt(s)", + description, + attempt + 1, + ) return None - delay = calculate_delay(attempt, base_delay, multiplier, max_delay, jitter) + # Provider-supplied Retry-After (e.g. 429/503) wins over our exponential + # backoff — matches the behavior the OpenAI/Azure SDKs give natively. + retry_after = _extract_retry_after(error) + if retry_after is not None: + delay = retry_after + else: + delay = calculate_delay(attempt, base_delay, multiplier, max_delay, jitter) logger_instance.warning( "Retry %d/%d for %s: %s (waiting %.1fs)", attempt + 1, @@ -182,12 +221,19 @@ def iter_with_retry( log = logger_instance or logger for attempt in range(max_retries + 1): has_yielded = False + gen = fn() try: - for item in fn(): + for item in gen: has_yielded = True yield item return except Exception as e: + # Close generator to release in-flight HTTP/socket resources + # before retrying — otherwise streaming providers leak sockets + # until GC. + close = getattr(gen, "close", None) + if callable(close): + close() if has_yielded: raise delay = _get_retry_delay( @@ -300,13 +346,8 @@ def _invoke_with_retries( if delay is not None: time.sleep(delay) continue - if attempt > 0: - logger_instance.exception( - "Giving up '%s' after %d attempt(s) for %s", - func.__name__, - attempt + 1, - prefix, - ) + # Give-up log is emitted inside _get_retry_delay so all retry + # helpers share the same exhaustion signal. raise if attempt > 0: logger_instance.info( From a8453e1b8026f7b0149bf578077ec13e5f7f3d62 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Fri, 17 Apr 2026 15:29:26 +0530 Subject: [PATCH 10/10] UN-3344 [REFACTOR] Share retry-kwargs helper and add TypeVar to retry wrappers Address review comments on PR #1886: - #9 (typing): call_with_retry / acall_with_retry / iter_with_retry previously returned `object`, erasing caller type info. Add PEP 695 generics so the return type flows from the wrapped callable: acall_with_retry now takes Callable[[], Awaitable[T]] and iter_with_retry takes Callable[[], Iterable[T]] -> Generator[T, ...]. - #11 / #13 (DRY): `_pop_retry_params` in embedding.py and `_disable_litellm_retry` in llm.py were identical logic. Lift to shared `pop_litellm_retry_kwargs` helper in retry_utils.py and delete both methods. Co-Authored-By: Claude Opus 4.7 (1M context) --- unstract/sdk1/src/unstract/sdk1/embedding.py | 22 ++------ unstract/sdk1/src/unstract/sdk1/llm.py | 31 ++++-------- .../src/unstract/sdk1/utils/retry_utils.py | 50 +++++++++++++++---- 3 files changed, 55 insertions(+), 48 deletions(-) diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index 2629538e4b..4e30c6201c 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -19,6 +19,7 @@ acall_with_retry, call_with_retry, is_retryable_litellm_error, + pop_litellm_retry_kwargs, ) if TYPE_CHECKING: @@ -118,25 +119,12 @@ def _get_adapter_info(self) -> str: return f"{self._adapter_name} ({name})" return name - def _pop_retry_params(self, kwargs: dict[str, object]) -> int: - """Extract max_retries and disable litellm's SDK-level retry.""" - max_retries = kwargs.pop("max_retries", None) or 0 - kwargs["max_retries"] = 0 - kwargs["num_retries"] = 0 - logger.debug( - "Embedding: extracted max_retries=%d, " - "disabled litellm retry (max_retries=0, num_retries=0) for %s", - max_retries, - self._get_adapter_info(), - ) - return max_retries - def get_embedding(self, text: str) -> list[float]: """Return embedding vector for query string.""" try: kwargs = self.kwargs.copy() model = kwargs.pop("model") - max_retries = self._pop_retry_params(kwargs) + max_retries = pop_litellm_retry_kwargs(kwargs, self._get_adapter_info()) resp = call_with_retry( lambda: litellm.embedding(model=model, input=[text], **kwargs), @@ -153,7 +141,7 @@ def get_embeddings(self, texts: list[str]) -> list[list[float]]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") - max_retries = self._pop_retry_params(kwargs) + max_retries = pop_litellm_retry_kwargs(kwargs, self._get_adapter_info()) resp = call_with_retry( lambda: litellm.embedding(model=model, input=texts, **kwargs), @@ -170,7 +158,7 @@ async def get_aembedding(self, text: str) -> list[float]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") - max_retries = self._pop_retry_params(kwargs) + max_retries = pop_litellm_retry_kwargs(kwargs, self._get_adapter_info()) resp = await acall_with_retry( lambda: litellm.aembedding(model=model, input=[text], **kwargs), @@ -187,7 +175,7 @@ async def get_aembeddings(self, texts: list[str]) -> list[list[float]]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") - max_retries = self._pop_retry_params(kwargs) + max_retries = pop_litellm_retry_kwargs(kwargs, self._get_adapter_info()) resp = await acall_with_retry( lambda: litellm.aembedding(model=model, input=texts, **kwargs), diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index 2f5bbc6242..c1730e6613 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -29,6 +29,7 @@ call_with_retry, is_retryable_litellm_error, iter_with_retry, + pop_litellm_retry_kwargs, ) logger = logging.getLogger(__name__) @@ -291,7 +292,9 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]: # if hasattr(self, "thinking_dict") and self.thinking_dict is not None: # completion_kwargs["temperature"] = 1 - max_retries = self._disable_litellm_retry(completion_kwargs) + max_retries = pop_litellm_retry_kwargs( + completion_kwargs, self._get_adapter_info() + ) response: dict[str, object] = call_with_retry( lambda: litellm.completion(messages=messages, **completion_kwargs), max_retries=max_retries, @@ -382,7 +385,9 @@ def stream_complete( completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) - max_retries = self._disable_litellm_retry(completion_kwargs) + max_retries = pop_litellm_retry_kwargs( + completion_kwargs, self._get_adapter_info() + ) has_yielded_content = False for chunk in iter_with_retry( lambda: litellm.completion( @@ -450,7 +455,9 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]: completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) - max_retries = self._disable_litellm_retry(completion_kwargs) + max_retries = pop_litellm_retry_kwargs( + completion_kwargs, self._get_adapter_info() + ) response = await acall_with_retry( lambda: litellm.acompletion(messages=messages, **completion_kwargs), max_retries=max_retries, @@ -548,24 +555,6 @@ def get_metrics(self) -> dict[str, object]: def get_usage_reason(self) -> object: return self.platform_kwargs.get("llm_usage_reason") - @staticmethod - def _disable_litellm_retry(kwargs: dict[str, Any]) -> int: - """Extract max_retries from kwargs and disable litellm's own retry. - - Returns the user-configured max_retries value. - """ - max_retries = kwargs.pop("max_retries", None) or 0 - # Prevent SDK-level retry (OpenAI/Azure native SDK) - kwargs["max_retries"] = 0 - # Prevent litellm wrapper retry (completion_with_retries) - kwargs["num_retries"] = 0 - logger.debug( - "LLM: extracted max_retries=%d, " - "disabled litellm retry (max_retries=0, num_retries=0)", - max_retries, - ) - return max_retries - def _record_usage( self, model: str, diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 5aa13df729..0a13331474 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -7,7 +7,7 @@ import os import random import time -from collections.abc import Callable, Generator +from collections.abc import Awaitable, Callable, Generator, Iterable from functools import wraps from typing import Any @@ -158,14 +158,42 @@ def _validate_max_retries(max_retries: int) -> None: raise ValueError(f"max_retries must be >= 0, got {max_retries}") -def call_with_retry( - fn: Callable[[], object], +def pop_litellm_retry_kwargs(kwargs: dict[str, Any], context: str = "") -> int: + """Pop max_retries from kwargs and disable litellm's built-in retries. + + litellm has two separate retry mechanisms: + - max_retries: passed to the SDK client (OpenAI/Azure) as its + constructor arg — triggers SDK-level retries. + - num_retries: activates litellm's own completion_with_retries wrapper. + + Both are zeroed so the outer retry helpers (call_with_retry etc.) are + the single source of truth. Note that num_retries=0 is dropped from + embedding kwargs by litellm.drop_params=True, but setting it keeps the + intent explicit and consistent across LLM/embedding paths. + + Returns the user-configured max_retries value (or 0 if unset). + """ + max_retries = kwargs.pop("max_retries", None) or 0 + kwargs["max_retries"] = 0 + kwargs["num_retries"] = 0 + suffix = f" for {context}" if context else "" + logger.debug( + "Extracted max_retries=%d, disabled litellm retry " + "(max_retries=0, num_retries=0)%s", + max_retries, + suffix, + ) + return max_retries + + +def call_with_retry[T]( + fn: Callable[[], T], *, max_retries: int, retry_predicate: Callable[[Exception], bool], description: str = "", logger_instance: logging.Logger | None = None, -) -> object: +) -> T: """Execute fn() with retry on transient errors.""" _validate_max_retries(max_retries) log = logger_instance or logger @@ -179,16 +207,17 @@ def call_with_retry( if delay is None: raise time.sleep(delay) + raise RuntimeError("unreachable") # for type-checker: loop always returns or raises -async def acall_with_retry( - fn: Callable[[], object], +async def acall_with_retry[T]( + fn: Callable[[], Awaitable[T]], *, max_retries: int, retry_predicate: Callable[[Exception], bool], description: str = "", logger_instance: logging.Logger | None = None, -) -> object: +) -> T: """Async version of call_with_retry — awaits fn().""" _validate_max_retries(max_retries) log = logger_instance or logger @@ -202,16 +231,17 @@ async def acall_with_retry( if delay is None: raise await asyncio.sleep(delay) + raise RuntimeError("unreachable") # for type-checker: loop always returns or raises -def iter_with_retry( - fn: Callable[[], object], +def iter_with_retry[T]( + fn: Callable[[], Iterable[T]], *, max_retries: int, retry_predicate: Callable[[Exception], bool], description: str = "", logger_instance: logging.Logger | None = None, -) -> Generator: +) -> Generator[T, None, None]: """Yield from fn() with retry. Only retries before the first yield. Once items have been yielded to the caller a mid-iteration failure is