diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json index 7363a41aab..f1ca8d2230 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/azure.json @@ -61,6 +61,14 @@ "title": "Embedding Batch Size", "default": 5 }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 3, + "description": "The maximum number of times to retry a request if it fails." + }, "timeout": { "type": "number", "minimum": 0, diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json index b71028148c..81829a0f49 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/bedrock.json @@ -43,8 +43,8 @@ "minimum": 0, "multipleOf": 1, "title": "Max Retries", - "default": 5, - "description": "Maximum number of retries to attempt when a request fails." + "default": 3, + "description": "The maximum number of times to retry a request if it fails." }, "timeout": { "type": "number", diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json index fe292f683c..8dd9bfa1c3 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/ollama.json @@ -31,6 +31,14 @@ "multipleOf": 1, "title": "Embed Batch Size", "default": 10 + }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 3, + "description": "The maximum number of times to retry a request if it fails." } } } diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json index 9be724e41f..3ad21d3564 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/openai.json @@ -44,6 +44,14 @@ "title": "Embed Batch Size", "default": 10 }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 3, + "description": "The maximum number of times to retry a request if it fails." + }, "timeout": { "type": "number", "minimum": 0, diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json index 1534fcf93d..6aa48e883f 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/vertexai.json @@ -57,6 +57,14 @@ "retrieval" ], "default": "default" + }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 3, + "description": "The maximum number of times to retry a request if it fails." } } } diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json index 3800814c77..3c8a4a5f16 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/ollama.json @@ -48,6 +48,14 @@ "default": 3900, "description": "The maximum number of context tokens for the model." }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 3, + "description": "The maximum number of times to retry a request if it fails." + }, "request_timeout": { "type": "number", "minimum": 0, diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index e54a093393..0d13221617 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os from typing import TYPE_CHECKING @@ -14,10 +15,17 @@ from unstract.sdk1.exceptions import SdkError, parse_litellm_err from unstract.sdk1.platform import PlatformHelper from unstract.sdk1.utils.callback_manager import CallbackManager +from unstract.sdk1.utils.retry_utils import ( + acall_with_retry, + call_with_retry, + is_retryable_litellm_error, +) if TYPE_CHECKING: from unstract.sdk1.tool.base import BaseTool +logger = logging.getLogger(__name__) + litellm.drop_params = True @@ -110,14 +118,32 @@ def _get_adapter_info(self) -> str: return f"{self._adapter_name} ({name})" return name + def _pop_retry_params(self, kwargs: dict[str, object]) -> int: + """Extract max_retries and disable litellm's SDK-level retry.""" + max_retries = kwargs.pop("max_retries", None) or 0 + kwargs["max_retries"] = 0 + kwargs["num_retries"] = 0 + logger.debug( + "Embedding: extracted max_retries=%d, " + "disabled litellm retry (max_retries=0, num_retries=0) for %s", + max_retries, + self._get_adapter_info(), + ) + return max_retries + def get_embedding(self, text: str) -> list[float]: """Return embedding vector for query string.""" try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = litellm.embedding(model=model, input=[text], **kwargs) - + resp = call_with_retry( + lambda: litellm.embedding(model=model, input=[text], **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) return resp["data"][0]["embedding"] except Exception as e: raise parse_litellm_err(e, self._get_adapter_info()) from e @@ -127,9 +153,14 @@ def get_embeddings(self, texts: list[str]) -> list[list[float]]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = litellm.embedding(model=model, input=texts, **kwargs) - + resp = call_with_retry( + lambda: litellm.embedding(model=model, input=texts, **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) return [data["embedding"] for data in resp["data"]] except Exception as e: raise parse_litellm_err(e, self._get_adapter_info()) from e @@ -139,26 +170,34 @@ async def get_aembedding(self, text: str) -> list[float]: try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = await litellm.aembedding(model=model, input=[text], **kwargs) - + resp = await acall_with_retry( + lambda: litellm.aembedding(model=model, input=[text], **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) return resp["data"][0]["embedding"] except Exception as e: - provider_name = f"{self.adapter.get_name()}" - raise parse_litellm_err(e, provider_name) from e + raise parse_litellm_err(e, self._get_adapter_info()) from e async def get_aembeddings(self, texts: list[str]) -> list[list[float]]: """Return async embedding vectors for list of query strings.""" try: kwargs = self.kwargs.copy() model = kwargs.pop("model") + max_retries = self._pop_retry_params(kwargs) - resp = await litellm.aembedding(model=model, input=texts, **kwargs) - + resp = await acall_with_retry( + lambda: litellm.aembedding(model=model, input=texts, **kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), + ) return [data["embedding"] for data in resp["data"]] except Exception as e: - provider_name = f"{self.adapter.get_name()}" - raise parse_litellm_err(e, provider_name) from e + raise parse_litellm_err(e, self._get_adapter_info()) from e def test_connection(self) -> bool: """Test connection to the embedding provider.""" diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index 8ff29a89d5..2f5bbc6242 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -24,6 +24,12 @@ TokenCounterCompat, capture_metrics, ) +from unstract.sdk1.utils.retry_utils import ( + acall_with_retry, + call_with_retry, + is_retryable_litellm_error, + iter_with_retry, +) logger = logging.getLogger(__name__) @@ -285,9 +291,12 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]: # if hasattr(self, "thinking_dict") and self.thinking_dict is not None: # completion_kwargs["temperature"] = 1 - response: dict[str, object] = litellm.completion( - messages=messages, - **completion_kwargs, + max_retries = self._disable_litellm_retry(completion_kwargs) + response: dict[str, object] = call_with_retry( + lambda: litellm.completion(messages=messages, **completion_kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ) response_text = response["choices"][0]["message"]["content"] @@ -373,14 +382,18 @@ def stream_complete( completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) + max_retries = self._disable_litellm_retry(completion_kwargs) has_yielded_content = False - for chunk in litellm.completion( - messages=messages, - stream=True, - stream_options={ - "include_usage": True, - }, - **completion_kwargs, + for chunk in iter_with_retry( + lambda: litellm.completion( + messages=messages, + stream=True, + stream_options={"include_usage": True}, + **completion_kwargs, + ), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ): if chunk.get("usage"): self._record_usage( @@ -437,9 +450,12 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]: completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) - response = await litellm.acompletion( - messages=messages, - **completion_kwargs, + max_retries = self._disable_litellm_retry(completion_kwargs) + response = await acall_with_retry( + lambda: litellm.acompletion(messages=messages, **completion_kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ) response_text = response["choices"][0]["message"]["content"] finish_reason = response["choices"][0].get("finish_reason") @@ -532,6 +548,24 @@ def get_metrics(self) -> dict[str, object]: def get_usage_reason(self) -> object: return self.platform_kwargs.get("llm_usage_reason") + @staticmethod + def _disable_litellm_retry(kwargs: dict[str, Any]) -> int: + """Extract max_retries from kwargs and disable litellm's own retry. + + Returns the user-configured max_retries value. + """ + max_retries = kwargs.pop("max_retries", None) or 0 + # Prevent SDK-level retry (OpenAI/Azure native SDK) + kwargs["max_retries"] = 0 + # Prevent litellm wrapper retry (completion_with_retries) + kwargs["num_retries"] = 0 + logger.debug( + "LLM: extracted max_retries=%d, " + "disabled litellm retry (max_retries=0, num_retries=0)", + max_retries, + ) + return max_retries + def _record_usage( self, model: str, diff --git a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py index 9d1a1e4ec1..03392c13e4 100644 --- a/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py +++ b/unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py @@ -1,35 +1,213 @@ """Generic retry utilities with custom exponential backoff implementation.""" +import asyncio +import builtins import errno import logging import os import random import time -from collections.abc import Callable +from collections.abc import Callable, Generator from functools import wraps from typing import Any -from requests.exceptions import ConnectionError, HTTPError, Timeout +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import HTTPError, Timeout logger = logging.getLogger(__name__) +# HTTP status codes that indicate transient server-side failures worth retrying. +RETRYABLE_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504}) + +# Exception class names (from litellm, openai, httpx) that indicate transient +# connection/timeout failures. Resolved via duck-typing to avoid importing +# litellm in this utility module. +_RETRYABLE_ERROR_NAMES = frozenset( + { + "APIConnectionError", + "APITimeoutError", + "Timeout", + "ConnectTimeout", + "ReadTimeout", + } +) -def is_retryable_error(error: Exception) -> bool: - """Check if an error is retryable. - Handles: - - ConnectionError and Timeout from requests - - HTTPError with status codes 502, 503, 504 - - OSError with specific errno codes (ECONNREFUSED, ECONNRESET, etc.) +def is_retryable_litellm_error(error: Exception) -> bool: + """Check if a litellm/provider API error should trigger a retry. - Args: - error: The exception to check + Distinct from is_retryable_error() which handles requests-library exceptions + (requests.ConnectionError, requests.HTTPError.response.status_code, OSError). + litellm/openai/httpx have a separate exception hierarchy: status_code lives + on the exception itself, and class names like APIConnectionError don't inherit + from the requests types. Uses duck-typing to avoid importing litellm directly. + """ + # Python built-in connection / timeout base classes (not requests.ConnectionError) + if isinstance(error, builtins.ConnectionError | builtins.TimeoutError): + return True - Returns: - True if the error should trigger a retry + # litellm/openai/httpx exception types that don't inherit from the + # built-ins above but still represent transient network failures. + # Check MRO to also catch subclasses of these error types. + if any(cls.__name__ in _RETRYABLE_ERROR_NAMES for cls in type(error).__mro__): + return True + + # Status-code check covers litellm.RateLimitError (429), + # InternalServerError (500), ServiceUnavailableError (503), etc. + status_code = getattr(error, "status_code", None) + if status_code is not None and status_code in RETRYABLE_STATUS_CODES: + return True + + return False + + +# ── Shared retry decision ─────────────────────────────────────────────────── + + +def _get_retry_delay( + error: Exception, + attempt: int, + max_retries: int, + retry_predicate: Callable[[Exception], bool] | None, + description: str, + logger_instance: logging.Logger, + base_delay: float = 1.0, + multiplier: float = 2.0, + max_delay: float = 60.0, + jitter: bool = True, +) -> float | None: + """Decide whether to retry and compute the backoff delay. + + Returns delay in seconds if the error is retryable, None otherwise. + The caller is responsible for sleeping (sync or async) and re-raising + when None is returned. + """ + should_retry = retry_predicate(error) if retry_predicate is not None else True + + logger_instance.debug( + "Retry decision: attempt=%d/%d error=%s retryable=%s description=%s", + attempt + 1, + max_retries + 1, + type(error).__name__, + should_retry, + description, + ) + + if not should_retry or attempt >= max_retries: + return None + + delay = calculate_delay(attempt, base_delay, multiplier, max_delay, jitter) + logger_instance.warning( + "Retry %d/%d for %s: %s (waiting %.1fs)", + attempt + 1, + max_retries, + description, + error, + delay, + ) + return delay + + +# ── Generic retry wrappers ────────────────────────────────────────────────── +# Unlike the decorator-based retry_with_exponential_backoff (env-var configured, +# sync-only), these accept max_retries at call time and support async + generators. +# All delegate retry decisions to _get_retry_delay above. + + +def _validate_max_retries(max_retries: int) -> None: + if max_retries < 0: + raise ValueError(f"max_retries must be >= 0, got {max_retries}") + + +def call_with_retry( + fn: Callable[[], object], + *, + max_retries: int, + retry_predicate: Callable[[Exception], bool], + description: str = "", + logger_instance: logging.Logger | None = None, +) -> object: + """Execute fn() with retry on transient errors.""" + _validate_max_retries(max_retries) + log = logger_instance or logger + for attempt in range(max_retries + 1): + try: + return fn() + except Exception as e: + delay = _get_retry_delay( + e, attempt, max_retries, retry_predicate, description, log + ) + if delay is None: + raise + time.sleep(delay) + + +async def acall_with_retry( + fn: Callable[[], object], + *, + max_retries: int, + retry_predicate: Callable[[Exception], bool], + description: str = "", + logger_instance: logging.Logger | None = None, +) -> object: + """Async version of call_with_retry — awaits fn().""" + _validate_max_retries(max_retries) + log = logger_instance or logger + for attempt in range(max_retries + 1): + try: + return await fn() + except Exception as e: + delay = _get_retry_delay( + e, attempt, max_retries, retry_predicate, description, log + ) + if delay is None: + raise + await asyncio.sleep(delay) + + +def iter_with_retry( + fn: Callable[[], object], + *, + max_retries: int, + retry_predicate: Callable[[Exception], bool], + description: str = "", + logger_instance: logging.Logger | None = None, +) -> Generator: + """Yield from fn() with retry. Only retries before the first yield. + + Once items have been yielded to the caller a mid-iteration failure is + raised immediately — partial output can't be un-yielded. + """ + _validate_max_retries(max_retries) + log = logger_instance or logger + for attempt in range(max_retries + 1): + has_yielded = False + try: + for item in fn(): + has_yielded = True + yield item + return + except Exception as e: + if has_yielded: + raise + delay = _get_retry_delay( + e, attempt, max_retries, retry_predicate, description, log + ) + if delay is None: + raise + time.sleep(delay) + + +def is_retryable_error(error: Exception) -> bool: + """Check if a requests-library HTTP error should trigger a retry. + + For retrying internal service calls (platform-service, prompt-service) that + use the requests library. Distinct from is_retryable_litellm_error() which + handles litellm/openai/httpx exceptions with different class hierarchies + (e.g. error.status_code vs error.response.status_code). """ # Requests connection and timeout errors - if isinstance(error, ConnectionError | Timeout): + if isinstance(error, RequestsConnectionError | Timeout): return True # HTTP errors with specific status codes @@ -85,7 +263,7 @@ def calculate_delay( return min(delay, max_delay) -def retry_with_exponential_backoff( # noqa: C901 +def retry_with_exponential_backoff( max_retries: int, base_delay: float, multiplier: float, @@ -111,38 +289,33 @@ def retry_with_exponential_backoff( # noqa: C901 Decorator function """ - def decorator(func: Callable) -> Callable: # noqa: C901 + def decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: C901, ANN401 - last_exception = None - - for attempt in range(max_retries + 1): # +1 for initial attempt + def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + for attempt in range(max_retries + 1): try: - # Try to execute the function result = func(*args, **kwargs) - - # If successful and we had retried, log success if attempt > 0: logger_instance.info( "Successfully completed '%s' after %d retry attempt(s)", func.__name__, attempt, ) - return result - except exceptions as e: - last_exception = e - - # Check if the error should trigger a retry - # First check if it's in the allowed exception types (already caught) - # Then check using the predicate if provided - should_retry = True - if retry_predicate is not None: - should_retry = retry_predicate(e) - - # If not retryable or last attempt, raise the error - if not should_retry or attempt == max_retries: + delay = _get_retry_delay( + e, + attempt, + max_retries, + retry_predicate, + prefix, + logger_instance, + base_delay, + multiplier, + 60.0, + jitter, + ) + if delay is None: if attempt > 0: logger_instance.exception( "Giving up '%s' after %d attempt(s) for %s", @@ -151,32 +324,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: C901, ANN401 prefix, ) raise - - # Calculate delay for next retry (capped at 60s) - delay = calculate_delay(attempt, base_delay, multiplier, 60.0, jitter) - - # Log retry attempt - logger_instance.warning( - "Retry %d/%d for %s: %s (waiting %.1fs)", - attempt + 1, - max_retries, - prefix, - e, - delay, - ) - - # Wait before retrying time.sleep(delay) - - except Exception as e: - # Exception not in the exceptions tuple - don't retry - last_exception = e + except Exception: raise - # This should never be reached, but just in case - if last_exception: - raise last_exception - return wrapper return decorator @@ -193,7 +344,7 @@ def create_retry_decorator( Args: prefix: Environment variable prefix for configuration exceptions: Tuple of exception types to retry on. - Defaults to (ConnectionError, HTTPError, Timeout, OSError) + Defaults to (RequestsConnectionError, HTTPError, Timeout, OSError) retry_predicate: Optional callable to determine if exception should trigger retry. If only exceptions list provided, retry on those exceptions. If only predicate provided, use predicate (catch all exceptions). @@ -212,7 +363,7 @@ def create_retry_decorator( # Handle different combinations of exceptions and predicate if exceptions is None and retry_predicate is None: # Default case: use specific exceptions with is_retryable_error predicate - exceptions = (ConnectionError, HTTPError, Timeout, OSError) + exceptions = (RequestsConnectionError, HTTPError, Timeout, OSError) retry_predicate = is_retryable_error elif exceptions is None and retry_predicate is not None: # Only predicate provided: catch all exceptions and use predicate