diff --git a/pyproject.toml b/pyproject.toml index b3c8bd3..fe12c50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["src/nyckel"] [project] name = "nyckel" -version = "0.4.21" +version = "0.4.22" authors = [{ name = "Oscar Beijbom", email = "oscar@nyckel.com" }] description = "Python package for the Nyckel API" readme = "README.md" diff --git a/src/nyckel/functions/tags/tags.py b/src/nyckel/functions/tags/tags.py index f38d9f3..ec4f6e6 100644 --- a/src/nyckel/functions/tags/tags.py +++ b/src/nyckel/functions/tags/tags.py @@ -4,6 +4,7 @@ from nyckel.data_classes import NyckelId from nyckel.functions.classification.classification import ( ClassificationPrediction, + ClassificationPredictionOrError, ImageSampleData, TabularSampleData, TextSampleData, @@ -16,7 +17,7 @@ class TagsAnnotation: present: bool = True -TagsPrediction = Sequence[ClassificationPrediction] +TagsPrediction = Sequence[ClassificationPredictionOrError] @dataclass diff --git a/src/nyckel/functions/tags/tags_sample_handler.py b/src/nyckel/functions/tags/tags_sample_handler.py index 12cee7b..9f054a5 100644 --- a/src/nyckel/functions/tags/tags_sample_handler.py +++ b/src/nyckel/functions/tags/tags_sample_handler.py @@ -5,6 +5,7 @@ from nyckel import ( ClassificationPrediction, + ClassificationPredictionError, Credentials, ImageTagsSample, TabularTagsSample, @@ -29,24 +30,7 @@ def __init__(self, function_id: str, credentials: Credentials) -> None: def invoke( self, sample_data_list: Union[List[str], List[Dict]], sample_data_transformer: Callable ) -> List[TagsPrediction]: - n_max_attempt = 5 - for _ in range(n_max_attempt): - invoke_ok, response_list = self._attempt_invoke(sample_data_list, sample_data_transformer) - if invoke_ok: - return self._parse_predictions_response(response_list) - else: - if "No model available to invoke function" in response_list[0].text: - print("Model not trained yet. Retrying...") - else: - raise RuntimeError(f"Failed to invoke function. {response_list=}") - time.sleep(5) - raise TimeoutError("Still no model after {n_max_attempt} attempts. Please try again later.") - - def _attempt_invoke( - self, - sample_data_list: Union[List[Dict], List[str]], - sample_data_transformer: Callable, - ) -> Tuple[bool, List[Any]]: + bodies = [{"data": sample_data} for sample_data in sample_data_list] def body_transformer(body: Dict) -> Dict: @@ -59,23 +43,28 @@ def body_transformer(body: Dict) -> Dict: poster = ParallelPoster(session, endpoint, progress_bar, body_transformer) response_list = poster(bodies) - if response_list[0].status_code in [200]: - return True, response_list - else: - return False, response_list + return self._parse_predictions_response(response_list) def _parse_predictions_response(self, response_list: List[Any]) -> List[TagsPrediction]: tags_predictions: List[TagsPrediction] = [] for response in response_list: - tags_prediction = [ - ClassificationPrediction( - label_name=entry["labelName"], - confidence=entry["confidence"], - ) - for entry in response.json() - ] + tags_prediction: TagsPrediction + if response.status_code == 200: + tags_prediction = [ + ClassificationPrediction( + label_name=entry["labelName"], + confidence=entry["confidence"], + ) + for entry in response.json() + ] + else: + tags_prediction = [ + ClassificationPredictionError( + error=response.text, + status_code=response.status_code, + ) + ] tags_predictions.append(tags_prediction) - return tags_predictions def create_samples(self, samples: TagsSampleList, sample_data_transformer: Callable) -> List[str]: diff --git a/src/nyckel/request_utils.py b/src/nyckel/request_utils.py index 4152aff..d993fee 100644 --- a/src/nyckel/request_utils.py +++ b/src/nyckel/request_utils.py @@ -66,7 +66,7 @@ def __call__(self, bodies: List[Dict]) -> List[requests.Response]: # Pull out and truncate data to avoid blowing up logs and error messages data = body.pop("data", None) - data = data[:100] if data else None + data = str(data)[:100] + "..." if data else None self.progress_bar.update(1) try: response = future.result() diff --git a/tests/test_text_tags_function.py b/tests/test_text_tags_function.py new file mode 100644 index 0000000..586f341 --- /dev/null +++ b/tests/test_text_tags_function.py @@ -0,0 +1,35 @@ +import unittest.mock +from typing import Dict + +import requests +from nyckel import ClassificationPrediction, ClassificationPredictionError, Credentials, TagsPrediction +from nyckel.functions.tags.tags_sample_handler import TagsSampleHandler + + +def test_server_side_invoke_error_handling(auth_test_credentials: Credentials) -> None: + + # Mock the _post_as_json method to return empty response for "invalid_input" + with unittest.mock.patch("nyckel.request_utils.ParallelPoster._post_as_json") as mock_post: + + def side_effect(data: Dict) -> requests.Response: + response = requests.Response() + if data.get("data") == "server side error": + response.status_code = 400 + response._content = b'"Invalid input"' + return response + + # Create successful response with JSON payload + response.status_code = 200 + response._content = ( + b'[{"labelName": "Nice", "confidence": 0.9}, {"labelName": "Friendly", "confidence": 0.8}]' + ) + return response + + mock_post.side_effect = side_effect + sample_handler = TagsSampleHandler("function_id", auth_test_credentials) + preds = sample_handler.invoke(["valid input", "server side error"], lambda x: x) + assert len(preds) == 2 + assert len(preds[0]) == 2 + assert isinstance(preds[0][0], ClassificationPrediction) + assert len(preds[1]) == 1 + assert isinstance(preds[1][0], ClassificationPredictionError)