Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/data_classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

::: nyckel.ClassificationPrediction

::: nyckel.ClassificationPredictionError

::: nyckel.ClassificationPredictionOrError

::: nyckel.ClassificationAnnotation

::: nyckel.TagsAnnotation
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ packages = ["src/nyckel"]

[project]
name = "nyckel"
version = "0.4.20"
version = "0.4.21"
authors = [{ name = "Oscar Beijbom", email = "oscar@nyckel.com" }]
description = "Python package for the Nyckel API"
readme = "README.md"
Expand Down
4 changes: 4 additions & 0 deletions src/nyckel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
ClassificationFunction, # noqa: F401
ClassificationLabel, # noqa: F401
ClassificationPrediction, # noqa: F401
ClassificationPredictionError, # noqa: F401
ClassificationPredictionOrError, # noqa: F401
ClassificationSample, # noqa: F401
ImageClassificationSample, # noqa: F401
ImageSampleData, # noqa: F401
Expand Down Expand Up @@ -46,6 +48,8 @@
"ClassificationFunction",
"ClassificationLabel",
"ClassificationPrediction",
"ClassificationPredictionError",
"ClassificationPredictionOrError",
"ClassificationSample",
"ImageClassificationSample",
"ImageSampleData",
Expand Down
9 changes: 9 additions & 0 deletions src/nyckel/functions/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ class ClassificationPrediction:
confidence: float


@dataclass
class ClassificationPredictionError:
    """Error outcome reported in place of a prediction.

    Attributes:
        error: Any content / information from the server.
        status_code: HTTP status code.
    """

    error: str
    status_code: int


# Type alias for a value that is either a successful prediction or an error record.
ClassificationPredictionOrError = Union[ClassificationPrediction, ClassificationPredictionError]


@dataclass
class ClassificationAnnotation:
label_name: str
Expand Down
3 changes: 2 additions & 1 deletion src/nyckel/functions/classification/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ClassificationFunction,
ClassificationLabel,
ClassificationPrediction,
ClassificationPredictionOrError,
Credentials,
ImageClassificationSample,
ImageEncoder,
Expand Down Expand Up @@ -91,7 +92,7 @@ def delete(self) -> None:

def invoke(  # type: ignore
    self, sample_data_list: List[ImageSampleData], model_id: str = ""
) -> List[ClassificationPredictionOrError]:
    """Invoke the function on a list of image samples.

    Delegates to the sample handler's ``invoke``, supplying an
    ``ImageSampleBodyTransformer`` as the sample-data transformer.

    Args:
        sample_data_list: Image samples to classify.
        model_id: Forwarded unchanged to the sample handler.

    Returns:
        Whatever the sample handler returns; annotated as a list of
        predictions or errors.
    """
    body_transformer = ImageSampleBodyTransformer()
    return self._sample_handler.invoke(sample_data_list, body_transformer, model_id=model_id)

def has_trained_model(self) -> bool:
Expand Down
62 changes: 29 additions & 33 deletions src/nyckel/functions/classification/sample_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import time
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Union

import requests
from tqdm import tqdm

from nyckel import (
ClassificationPrediction,
ClassificationPredictionError,
ClassificationPredictionOrError,
ClassificationSample,
Credentials,
ImageClassificationSample,
Expand Down Expand Up @@ -32,26 +35,8 @@ def invoke(
sample_data_list: Union[List[Dict], List[str]],
sample_data_transformer: Callable,
model_id: str = "",
) -> List[ClassificationPrediction]:
n_max_attempt = 5
for _ in range(n_max_attempt):
invoke_ok, response_list = self.attempt_invoke(sample_data_list, sample_data_transformer, model_id=model_id)
if invoke_ok:
return self.parse_predictions_response(response_list)
else:
if "No model available to invoke function" in response_list[0].text:
print("Model not trained yet. Retrying...")
else:
raise RuntimeError(f"Failed to invoke function. {response_list=}")
time.sleep(5)
raise TimeoutError("Still no model after {n_max_attempt} attempts. Please try again later.")

def attempt_invoke(
self,
sample_data_list: Union[List[Dict], List[str]],
sample_data_transformer: Callable,
model_id: str = "",
) -> Tuple[bool, List[Any]]:
) -> List[ClassificationPredictionOrError]:

bodies = [{"data": sample_data} for sample_data in sample_data_list]

def body_transformer(body: Dict) -> Dict:
Expand All @@ -64,19 +49,30 @@ def body_transformer(body: Dict) -> Dict:

poster = ParallelPoster(session, endpoint, progress_bar, body_transformer)
response_list = poster(bodies)
if response_list[0].status_code in [200]:
return True, response_list
else:
return False, response_list
parsed_responses = self.parse_predictions_response(response_list)

def parse_predictions_response(self, response_list: List[Any]) -> List[ClassificationPrediction]:
return [
ClassificationPrediction(
label_name=response.json()["labelName"],
confidence=response.json()["confidence"],
)
for response in response_list
]
return parsed_responses

def parse_predictions_response(
    self, response_list: List[requests.Response]
) -> List[ClassificationPredictionOrError]:
    """Convert raw HTTP responses into typed predictions or errors.

    Args:
        response_list: Responses to convert, one per invoked sample.

    Returns:
        A list with one entry per response, in the same order. A response
        with status code 200 becomes a ``ClassificationPrediction`` built from
        the ``labelName`` and ``confidence`` fields of its JSON body; any other
        response becomes a ``ClassificationPredictionError`` carrying the
        response text and status code.
    """
    typed_responses: List[ClassificationPredictionOrError] = []
    for response in response_list:
        if response.status_code != 200:
            typed_responses.append(
                ClassificationPredictionError(
                    error=response.text,
                    status_code=response.status_code,
                )
            )
            continue

        # Decode the body once instead of once per field.
        payload = response.json()
        typed_responses.append(
            ClassificationPrediction(
                label_name=payload["labelName"],
                confidence=payload["confidence"],
            )
        )
    return typed_responses

def create_samples(self, samples: ClassificationSampleList, sample_data_transformer: Callable) -> List[str]:
bodies = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ClassificationFunction,
ClassificationLabel,
ClassificationPrediction,
ClassificationPredictionOrError,
Credentials,
LabelName,
NyckelId,
Expand Down Expand Up @@ -104,7 +105,7 @@ def invoke( # type: ignore
self,
sample_data_list: List[TabularSampleData],
model_id: str = "",
) -> List[ClassificationPrediction]:
) -> List[ClassificationPredictionOrError]:
return self._sample_handler.invoke(
sample_data_list, self._get_image_field_transformer("name"), model_id=model_id
)
Expand Down
3 changes: 2 additions & 1 deletion src/nyckel/functions/classification/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ClassificationFunction,
ClassificationLabel,
ClassificationPrediction,
ClassificationPredictionOrError,
Credentials,
LabelName,
NyckelId,
Expand Down Expand Up @@ -92,7 +93,7 @@ def invoke(
self,
sample_data_list: List[TextSampleData],
model_id: str = "",
) -> List[ClassificationPrediction]:
) -> List[ClassificationPredictionOrError]:
return self._sample_handler.invoke(sample_data_list, lambda x: x, model_id=model_id)

def has_trained_model(self) -> bool:
Expand Down
26 changes: 22 additions & 4 deletions src/nyckel/request_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import concurrent.futures
import time
import warnings
from json import JSONDecodeError
from typing import Callable, Dict, List, Optional, Tuple
Expand Down Expand Up @@ -31,7 +32,21 @@ def __init__(
self.progress_bar = tqdm("Posting", ncols=80)

def _post_as_json(self, data: Dict) -> requests.Response:
    """Post ``data`` as JSON to the endpoint, retrying failed attempts.

    The body is passed through ``self._body_transformer`` before every post.
    A response with status 200 or 409 is returned immediately. Any other
    status, or an exception raised while posting, leads to another attempt
    after an exponentially growing pause (0.5, 1, 2, 4 seconds), for at most
    five attempts in total.

    Args:
        data: The request body, before transformation.

    Returns:
        The first 200/409 response if there is one; otherwise the response of
        the final attempt. If the final attempt raised, a synthetic response
        with status code 400 whose JSON body describes the exception.
    """
    import json  # local import: only used to build the synthetic error body

    max_attempts = 5
    for attempt in range(max_attempts):
        is_last_attempt = attempt == max_attempts - 1
        try:
            response = self._session.post(self._endpoint, json=self._body_transformer(data))
            # 200 is success; 409 is a conflict, which is not retried.
            if response.status_code in (200, 409):
                return response
        except Exception as err:
            if is_last_attempt:
                response = requests.Response()
                response.status_code = 400
                # json.dumps escapes quotes/backslashes in the exception text,
                # so the body is always valid JSON.
                response._content = json.dumps(
                    {"error": "Failed to post data", "details": str(err)}
                ).encode()
                return response
        # Back off before the next attempt; there is nothing to wait for
        # once the final attempt has been made.
        if not is_last_attempt:
            time.sleep(0.5 * (2**attempt))
    return response

def refresh_session(self, session: requests.Session) -> None:
Expand All @@ -48,18 +63,21 @@ def __call__(self, bodies: List[Dict]) -> List[requests.Response]:
for future in concurrent.futures.as_completed(index_by_future):
index = index_by_future[future]
body = bodies[index]
body.pop("data", None) # data is too large for logs and error messages

# Pull out and truncate data to avoid blowing up logs and error messages
data = body.pop("data", None)
data = data[:100] if data else None
self.progress_bar.update(1)
try:
response = future.result()
if response.status_code not in [200, 409]:
warnings.warn(
f"Posting {body} to {self._endpoint} failed with {response.status_code=} {response.text=}",
f"Posting {data} to {self._endpoint} failed with {response.status_code=} {response.text=}",
RuntimeWarning,
)
responses[index] = response
except Exception as e:
warnings.warn(f"Posting {body} to {self._endpoint} failed with {e}", RuntimeWarning)
warnings.warn(f"Posting {data} to {self._endpoint} failed with {e}", RuntimeWarning)

return responses

Expand Down
60 changes: 59 additions & 1 deletion tests/test_text_classification_function.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import time
from typing import Tuple, Union
import unittest.mock
from typing import Dict, Tuple, Union

import pytest
import requests
from nyckel import (
ClassificationAnnotation,
ClassificationLabel,
ClassificationPrediction,
ClassificationPredictionError,
Credentials,
TextClassificationFunction,
TextClassificationSample,
)
from nyckel.functions.classification.sample_handler import ClassificationSampleHandler

post_sample_parameter_examples = [
TextClassificationSample(data="Hi neighbor!", annotation=ClassificationAnnotation(label_name="Nice")),
Expand Down Expand Up @@ -87,3 +92,56 @@ def test_end_to_end(text_classification_function: TextClassificationFunction) ->
returned_samples = func.list_samples()
assert len(samples) == len(returned_samples)
assert isinstance(returned_samples[0], TextClassificationSample)


def test_server_side_invoke_error_handling(auth_test_credentials: Credentials) -> None:
    """A non-200 response for one sample yields an error entry alongside the valid prediction."""

    def fake_post_as_json(data: Dict) -> requests.Response:
        # Answer 400 for the sentinel sample and a successful prediction payload otherwise.
        failed = data.get("data") == "server side error"
        response = requests.Response()
        if failed:
            response.status_code = 400
            response._content = b'{"error": "Invalid input"}'
        else:
            response.status_code = 200
            response._content = b'{"labelName": "Nice", "confidence": 0.9}'
        return response

    # Patch ParallelPoster._post_as_json so every post is answered by the stub above.
    with unittest.mock.patch("nyckel.request_utils.ParallelPoster._post_as_json") as mock_post:
        mock_post.side_effect = fake_post_as_json
        handler = ClassificationSampleHandler("fid", auth_test_credentials)
        preds = handler.invoke(["valid input", "server side error"], lambda x: x)

    assert len(preds) == 2
    assert isinstance(preds[0], ClassificationPrediction)
    assert isinstance(preds[1], ClassificationPredictionError)
    assert preds[1].status_code == 400


def test_client_side_invoke_error_handling(auth_test_credentials: Credentials) -> None:
    """An exception raised while posting one sample yields an error entry with status 400."""

    def fake_session_post(url: str, json: Dict) -> requests.Response:
        # Raise for the sentinel sample; otherwise return a successful prediction payload.
        if json.get("data") == "SDK side error":
            raise Exception("Client side error")
        response = requests.Response()
        response.status_code = 200
        response._content = b'{"labelName": "Nice", "confidence": 0.9}'
        return response

    # Patch requests.Session.post so every post is answered by the stub above.
    with unittest.mock.patch("requests.Session.post") as mock_post:
        mock_post.side_effect = fake_session_post
        handler = ClassificationSampleHandler("fid", auth_test_credentials)
        preds = handler.invoke(["valid input", "SDK side error"], lambda x: x)

    assert len(preds) == 2
    assert isinstance(preds[0], ClassificationPrediction)
    assert isinstance(preds[1], ClassificationPredictionError)
    assert preds[1].status_code == 400
Loading