Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@
"title": "Embedding Batch Size",
"default": 5
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 3,
"description": "The maximum number of times to retry a request if it fails."
},
"timeout": {
"type": "number",
"minimum": 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 5,
"description": "Maximum number of retries to attempt when a request fails."
"default": 3,
"description": "The maximum number of times to retry a request if it fails."
},
"timeout": {
"type": "number",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
"multipleOf": 1,
"title": "Embed Batch Size",
"default": 10
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 3,
"description": "The maximum number of times to retry a request if it fails."
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
"title": "Embed Batch Size",
"default": 10
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 3,
"description": "The maximum number of times to retry a request if it fails."
},
"timeout": {
"type": "number",
"minimum": 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@
"retrieval"
],
"default": "default"
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 3,
"description": "The maximum number of times to retry a request if it fails."
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@
"default": 3900,
"description": "The maximum number of context tokens for the model."
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 3,
"description": "The maximum number of times to retry a request if it fails."
},
"request_timeout": {
"type": "number",
"minimum": 0,
Expand Down
63 changes: 51 additions & 12 deletions unstract/sdk1/src/unstract/sdk1/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING

Expand All @@ -14,10 +15,17 @@
from unstract.sdk1.exceptions import SdkError, parse_litellm_err
from unstract.sdk1.platform import PlatformHelper
from unstract.sdk1.utils.callback_manager import CallbackManager
from unstract.sdk1.utils.retry_utils import (
acall_with_retry,
call_with_retry,
is_retryable_litellm_error,
)

if TYPE_CHECKING:
from unstract.sdk1.tool.base import BaseTool

logger = logging.getLogger(__name__)

litellm.drop_params = True


Expand Down Expand Up @@ -110,14 +118,32 @@ def _get_adapter_info(self) -> str:
return f"{self._adapter_name} ({name})"
return name

def _pop_retry_params(self, kwargs: dict[str, object]) -> int:
"""Extract max_retries and disable litellm's SDK-level retry."""
max_retries = kwargs.pop("max_retries", None) or 0
kwargs["max_retries"] = 0
kwargs["num_retries"] = 0
logger.debug(
"Embedding: extracted max_retries=%d, "
"disabled litellm retry (max_retries=0, num_retries=0) for %s",
max_retries,
self._get_adapter_info(),
)
return max_retries
Comment thread
greptile-apps[bot] marked this conversation as resolved.

def get_embedding(self, text: str) -> list[float]:
"""Return embedding vector for query string."""
try:
kwargs = self.kwargs.copy()
model = kwargs.pop("model")
max_retries = self._pop_retry_params(kwargs)

resp = litellm.embedding(model=model, input=[text], **kwargs)

resp = call_with_retry(
lambda: litellm.embedding(model=model, input=[text], **kwargs),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
)
return resp["data"][0]["embedding"]
except Exception as e:
raise parse_litellm_err(e, self._get_adapter_info()) from e
Expand All @@ -127,9 +153,14 @@ def get_embeddings(self, texts: list[str]) -> list[list[float]]:
try:
kwargs = self.kwargs.copy()
model = kwargs.pop("model")
max_retries = self._pop_retry_params(kwargs)

resp = litellm.embedding(model=model, input=texts, **kwargs)

resp = call_with_retry(
lambda: litellm.embedding(model=model, input=texts, **kwargs),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
)
return [data["embedding"] for data in resp["data"]]
except Exception as e:
raise parse_litellm_err(e, self._get_adapter_info()) from e
Expand All @@ -139,26 +170,34 @@ async def get_aembedding(self, text: str) -> list[float]:
try:
kwargs = self.kwargs.copy()
model = kwargs.pop("model")
max_retries = self._pop_retry_params(kwargs)

resp = await litellm.aembedding(model=model, input=[text], **kwargs)

resp = await acall_with_retry(
lambda: litellm.aembedding(model=model, input=[text], **kwargs),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
)
return resp["data"][0]["embedding"]
except Exception as e:
provider_name = f"{self.adapter.get_name()}"
raise parse_litellm_err(e, provider_name) from e
raise parse_litellm_err(e, self._get_adapter_info()) from e

async def get_aembeddings(self, texts: list[str]) -> list[list[float]]:
"""Return async embedding vectors for list of query strings."""
try:
kwargs = self.kwargs.copy()
model = kwargs.pop("model")
max_retries = self._pop_retry_params(kwargs)

resp = await litellm.aembedding(model=model, input=texts, **kwargs)

resp = await acall_with_retry(
lambda: litellm.aembedding(model=model, input=texts, **kwargs),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
)
return [data["embedding"] for data in resp["data"]]
except Exception as e:
provider_name = f"{self.adapter.get_name()}"
raise parse_litellm_err(e, provider_name) from e
raise parse_litellm_err(e, self._get_adapter_info()) from e

def test_connection(self) -> bool:
"""Test connection to the embedding provider."""
Expand Down
60 changes: 47 additions & 13 deletions unstract/sdk1/src/unstract/sdk1/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
TokenCounterCompat,
capture_metrics,
)
from unstract.sdk1.utils.retry_utils import (
acall_with_retry,
call_with_retry,
is_retryable_litellm_error,
iter_with_retry,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -285,9 +291,12 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]:
# if hasattr(self, "thinking_dict") and self.thinking_dict is not None:
# completion_kwargs["temperature"] = 1

response: dict[str, object] = litellm.completion(
messages=messages,
**completion_kwargs,
max_retries = self._disable_litellm_retry(completion_kwargs)
response: dict[str, object] = call_with_retry(
lambda: litellm.completion(messages=messages, **completion_kwargs),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
)

response_text = response["choices"][0]["message"]["content"]
Expand Down Expand Up @@ -373,14 +382,18 @@ def stream_complete(
completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs})
completion_kwargs.pop("cost_model", None)

max_retries = self._disable_litellm_retry(completion_kwargs)
has_yielded_content = False
for chunk in litellm.completion(
messages=messages,
stream=True,
stream_options={
"include_usage": True,
},
**completion_kwargs,
for chunk in iter_with_retry(
lambda: litellm.completion(
messages=messages,
stream=True,
stream_options={"include_usage": True},
**completion_kwargs,
),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
):
if chunk.get("usage"):
self._record_usage(
Expand Down Expand Up @@ -437,9 +450,12 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]:
completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs})
completion_kwargs.pop("cost_model", None)

response = await litellm.acompletion(
messages=messages,
**completion_kwargs,
max_retries = self._disable_litellm_retry(completion_kwargs)
response = await acall_with_retry(
lambda: litellm.acompletion(messages=messages, **completion_kwargs),
max_retries=max_retries,
retry_predicate=is_retryable_litellm_error,
description=self._get_adapter_info(),
)
response_text = response["choices"][0]["message"]["content"]
finish_reason = response["choices"][0].get("finish_reason")
Expand Down Expand Up @@ -532,6 +548,24 @@ def get_metrics(self) -> dict[str, object]:
def get_usage_reason(self) -> object:
return self.platform_kwargs.get("llm_usage_reason")

@staticmethod
def _disable_litellm_retry(kwargs: dict[str, Any]) -> int:
"""Extract max_retries from kwargs and disable litellm's own retry.

Returns the user-configured max_retries value.
"""
max_retries = kwargs.pop("max_retries", None) or 0
# Prevent SDK-level retry (OpenAI/Azure native SDK)
kwargs["max_retries"] = 0
# Prevent litellm wrapper retry (completion_with_retries)
kwargs["num_retries"] = 0
logger.debug(
"LLM: extracted max_retries=%d, "
"disabled litellm retry (max_retries=0, num_retries=0)",
max_retries,
)
return max_retries

def _record_usage(
self,
model: str,
Expand Down
Loading