Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ packages = ["src/nyckel"]

[project]
name = "nyckel"
version = "0.4.21"
version = "0.4.22"
authors = [{ name = "Oscar Beijbom", email = "oscar@nyckel.com" }]
description = "Python package for the Nyckel API"
readme = "README.md"
Expand Down
3 changes: 2 additions & 1 deletion src/nyckel/functions/tags/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from nyckel.data_classes import NyckelId
from nyckel.functions.classification.classification import (
ClassificationPrediction,
ClassificationPredictionOrError,
ImageSampleData,
TabularSampleData,
TextSampleData,
Expand All @@ -16,7 +17,7 @@ class TagsAnnotation:
present: bool = True


TagsPrediction = Sequence[ClassificationPrediction]
TagsPrediction = Sequence[ClassificationPredictionOrError]


@dataclass
Expand Down
49 changes: 19 additions & 30 deletions src/nyckel/functions/tags/tags_sample_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from nyckel import (
ClassificationPrediction,
ClassificationPredictionError,
Credentials,
ImageTagsSample,
TabularTagsSample,
Expand All @@ -29,24 +30,7 @@ def __init__(self, function_id: str, credentials: Credentials) -> None:
def invoke(
self, sample_data_list: Union[List[str], List[Dict]], sample_data_transformer: Callable
) -> List[TagsPrediction]:
n_max_attempt = 5
for _ in range(n_max_attempt):
invoke_ok, response_list = self._attempt_invoke(sample_data_list, sample_data_transformer)
if invoke_ok:
return self._parse_predictions_response(response_list)
else:
if "No model available to invoke function" in response_list[0].text:
print("Model not trained yet. Retrying...")
else:
raise RuntimeError(f"Failed to invoke function. {response_list=}")
time.sleep(5)
raise TimeoutError("Still no model after {n_max_attempt} attempts. Please try again later.")

def _attempt_invoke(
self,
sample_data_list: Union[List[Dict], List[str]],
sample_data_transformer: Callable,
) -> Tuple[bool, List[Any]]:

bodies = [{"data": sample_data} for sample_data in sample_data_list]

def body_transformer(body: Dict) -> Dict:
Expand All @@ -59,23 +43,28 @@ def body_transformer(body: Dict) -> Dict:

poster = ParallelPoster(session, endpoint, progress_bar, body_transformer)
response_list = poster(bodies)
if response_list[0].status_code in [200]:
return True, response_list
else:
return False, response_list
return self._parse_predictions_response(response_list)

def _parse_predictions_response(self, response_list: List[Any]) -> List[TagsPrediction]:
tags_predictions: List[TagsPrediction] = []
for response in response_list:
tags_prediction = [
ClassificationPrediction(
label_name=entry["labelName"],
confidence=entry["confidence"],
)
for entry in response.json()
]
tags_prediction: TagsPrediction
if response.status_code == 200:
tags_prediction = [
ClassificationPrediction(
label_name=entry["labelName"],
confidence=entry["confidence"],
)
for entry in response.json()
]
else:
tags_prediction = [
ClassificationPredictionError(
error=response.text,
status_code=response.status_code,
)
]
tags_predictions.append(tags_prediction)

return tags_predictions

def create_samples(self, samples: TagsSampleList, sample_data_transformer: Callable) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion src/nyckel/request_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __call__(self, bodies: List[Dict]) -> List[requests.Response]:

# Pull out and truncate data to avoid blowing up logs and error messages
data = body.pop("data", None)
data = data[:100] if data else None
data = str(data)[:100] + "..." if data else None
self.progress_bar.update(1)
try:
response = future.result()
Expand Down
35 changes: 35 additions & 0 deletions tests/test_text_tags_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import unittest.mock
from typing import Dict

import requests
from nyckel import ClassificationPrediction, ClassificationPredictionError, Credentials, TagsPrediction
from nyckel.functions.tags.tags_sample_handler import TagsSampleHandler


def test_server_side_invoke_error_handling(auth_test_credentials: Credentials) -> None:

# Mock the _post_as_json method to return empty response for "invalid_input"
with unittest.mock.patch("nyckel.request_utils.ParallelPoster._post_as_json") as mock_post:

def side_effect(data: Dict) -> requests.Response:
response = requests.Response()
if data.get("data") == "server side error":
response.status_code = 400
response._content = b'"Invalid input"'
return response

# Create successful response with JSON payload
response.status_code = 200
response._content = (
b'[{"labelName": "Nice", "confidence": 0.9}, {"labelName": "Friendly", "confidence": 0.8}]'
)
return response

mock_post.side_effect = side_effect
sample_handler = TagsSampleHandler("function_id", auth_test_credentials)
preds = sample_handler.invoke(["valid input", "server side error"], lambda x: x)
assert len(preds) == 2
assert len(preds[0]) == 2
assert isinstance(preds[0][0], ClassificationPrediction)
assert len(preds[1]) == 1
assert isinstance(preds[1][0], ClassificationPredictionError)