diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5527f9da1..174256880 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ Changelog Next ---- +* Add ``request_timeout_seconds`` parameter to ``VWS`` and ``CloudRecoService``, allowing customization of the request timeout. This accepts a float or a ``(connect, read)`` tuple, matching the ``requests`` library's timeout interface. The default remains 30 seconds. + 2025.03.10.1 ------------ diff --git a/src/vws/query.py b/src/vws/query.py index 6c2dc481e..dd36e8f21 100644 --- a/src/vws/query.py +++ b/src/vws/query.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import requests -from beartype import beartype +from beartype import BeartypeConf, beartype from urllib3.filepost import encode_multipart_formdata from vws_auth_tools import authorization_header, rfc_1123_date @@ -40,7 +40,7 @@ def _get_image_data(image: _ImageType) -> bytes: return image_data -@beartype +@beartype(conf=BeartypeConf(is_pep484_tower=True)) class CloudRecoService: """An interface to the Vuforia Cloud Recognition Web APIs.""" @@ -49,16 +49,22 @@ def __init__( client_access_key: str, client_secret_key: str, base_vwq_url: str = "https://cloudreco.vuforia.com", + request_timeout_seconds: float | tuple[float, float] = 30.0, ) -> None: """ Args: client_access_key: A VWS client access key. client_secret_key: A VWS client secret key. base_vwq_url: The base URL for the VWQ API. + request_timeout_seconds: The timeout for each HTTP request, as + used by ``requests.request``. This can be a float to set + both the connect and read timeouts, or a (connect, read) + tuple. """ self._client_access_key = client_access_key self._client_secret_key = client_secret_key self._base_vwq_url = base_vwq_url + self._request_timeout_seconds = request_timeout_seconds def query( self, @@ -141,8 +147,7 @@ def query( url=urljoin(base=self._base_vwq_url, url=request_path), headers=headers, data=content, - # We should make the timeout customizable. - timeout=30, + timeout=self._request_timeout_seconds, ) response = Response( text=requests_response.text, diff --git a/src/vws/vws.py b/src/vws/vws.py index 9b3d19b4e..4cb75ca32 100644 --- a/src/vws/vws.py +++ b/src/vws/vws.py @@ -58,7 +58,7 @@ def _get_image_data(image: _ImageType) -> bytes: return image_data -@beartype +@beartype(conf=BeartypeConf(is_pep484_tower=True)) def _target_api_request( *, content_type: str, @@ -68,6 +68,7 @@ def _target_api_request( data: bytes, request_path: str, base_vws_url: str, + request_timeout_seconds: float | tuple[float, float], ) -> Response: """Make a request to the Vuforia Target API. @@ -82,6 +83,9 @@ def _target_api_request( request_path: The path to the endpoint which will be used in the request. base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout for the request, as used by + ``requests.request``. This can be a float to set both the + connect and read timeouts, or a (connect, read) tuple. Returns: The response to the request made by `requests`. @@ -111,8 +115,7 @@ def _target_api_request( url=url, headers=headers, data=data, - # We should make the timeout customizable. - timeout=30, + timeout=request_timeout_seconds, ) return Response( @@ -134,16 +137,22 @@ def __init__( server_access_key: str, server_secret_key: str, base_vws_url: str = "https://vws.vuforia.com", + request_timeout_seconds: float | tuple[float, float] = 30.0, ) -> None: """ Args: server_access_key: A VWS server access key. server_secret_key: A VWS server secret key. base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout for each HTTP request, as + used by ``requests.request``. This can be a float to set + both the connect and read timeouts, or a (connect, read) + tuple. """ self._server_access_key = server_access_key self._server_secret_key = server_secret_key self._base_vws_url = base_vws_url + self._request_timeout_seconds = request_timeout_seconds def make_request( self, @@ -187,6 +196,7 @@ def make_request( data=data, request_path=request_path, base_vws_url=self._base_vws_url, + request_timeout_seconds=self._request_timeout_seconds, ) if ( diff --git a/tests/test_query.py b/tests/test_query.py index 31284325d..d7b64aff2 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -92,6 +92,61 @@ def test_default_timeout( assert not matches +class TestCustomRequestTimeout: + """Tests for custom request timeout values.""" + + @staticmethod + @pytest.mark.parametrize( + argnames=( + "custom_timeout", + "response_delay_seconds", + "expect_timeout", + ), + argvalues=[ + (0.1, 0.09, False), + (0.1, 0.11, True), + ((5.0, 0.1), 0.09, False), + ((5.0, 0.1), 0.11, True), + ], + ) + def test_custom_timeout( + image: io.BytesIO | BinaryIO, + *, + custom_timeout: float | tuple[float, float], + response_delay_seconds: float, + expect_timeout: bool, + ) -> None: + """Custom timeouts are honored for both float and tuple forms.""" + with ( + freeze_time() as frozen_datetime, + MockVWS( + response_delay_seconds=response_delay_seconds, + sleep_fn=lambda seconds: ( + frozen_datetime.tick( + delta=datetime.timedelta(seconds=seconds), + ), + None, + )[1], + ) as mock, + ): + database = VuforiaDatabase() + mock.add_database(database=database) + cloud_reco_client = CloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + request_timeout_seconds=custom_timeout, + ) + + if expect_timeout: + with pytest.raises( + expected_exception=requests.exceptions.Timeout, + ): + cloud_reco_client.query(image=image) + else: + matches = cloud_reco_client.query(image=image) + assert not matches + + class TestCustomBaseVWQURL: """Tests for using a custom base VWQ URL.""" diff --git a/tests/test_vws.py b/tests/test_vws.py index 9a1c8ded9..f09906ee5 100644 --- a/tests/test_vws.py +++ b/tests/test_vws.py @@ -150,6 +150,72 @@ def test_default_timeout( ) +class TestCustomRequestTimeout: + """Tests for custom request timeout values.""" + + @staticmethod + @pytest.mark.parametrize( + argnames=( + "custom_timeout", + "response_delay_seconds", + "expect_timeout", + ), + argvalues=[ + (0.1, 0.09, False), + (0.1, 0.11, True), + ((5.0, 0.1), 0.09, False), + ((5.0, 0.1), 0.11, True), + ], + ) + def test_custom_timeout( + image: io.BytesIO | BinaryIO, + *, + custom_timeout: float | tuple[float, float], + response_delay_seconds: float, + expect_timeout: bool, + ) -> None: + """Custom timeouts are honored for both float and tuple forms.""" + with ( + freeze_time() as frozen_datetime, + MockVWS( + response_delay_seconds=response_delay_seconds, + sleep_fn=lambda seconds: ( + frozen_datetime.tick( + delta=datetime.timedelta(seconds=seconds), + ), + None, + )[1], + ) as mock, + ): + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + request_timeout_seconds=custom_timeout, + ) + + if expect_timeout: + with pytest.raises( + expected_exception=requests.exceptions.Timeout, + ): + vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + else: + vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + class TestCustomBaseVWSURL: """Tests for using a custom base VWS URL."""