From 65cfbc346da9d50511755ecba95c9c7037615241 Mon Sep 17 00:00:00 2001 From: Adir Amsalem Date: Mon, 20 Apr 2026 13:37:25 +0300 Subject: [PATCH] fix: surface server capacity error through connect() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the server rejects a connection at capacity, it sends a structured error message and then closes the WebSocket. The receive loop's finally block was overwriting the specific error with a generic "WebSocket disconnected" before the awaiting caller could read it, and subscribe mode had no pending waits to surface the error at all — callers hung until the 5-minute connection timeout. Store the first specific error on the connection and have both the finally block and connect()'s wait loop defer to it, matching the JS SDK's "first reject wins" semantics. --- decart/realtime/webrtc_connection.py | 9 ++- tests/test_realtime_unit.py | 88 ++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index ef5e0c5..12f53ed 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -61,6 +61,7 @@ def __init__( self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None self._local_track: Optional[MediaStreamTrack] = None self._model_name: Optional[str] = None + self._connection_error: Optional[str] = None async def connect( self, @@ -75,6 +76,7 @@ async def connect( try: self._local_track = local_track self._model_name = model_name + self._connection_error = None await self._set_state("connecting") @@ -108,6 +110,8 @@ async def connect( while asyncio.get_event_loop().time() < deadline: if self._state in ("connected", "generating"): return + if self._connection_error: + raise WebRTCError(self._connection_error) await asyncio.sleep(0.1) raise TimeoutError("Connection timeout") @@ -285,7 +289,8 @@ async def _receive_messages(self) -> None: if self._on_error: self._on_error(e) finally: - self._resolve_pending_waits("WebSocket disconnected") + final_error = self._connection_error or "WebSocket disconnected" + self._resolve_pending_waits(final_error) await self._set_state("disconnected") async def _handle_message(self, data: dict) -> None: @@ -384,6 +389,8 @@ def _handle_error(self, message: ErrorMessage) -> None: logger.error(f"Received error from server: {message.error}") error = WebRTCError(message.error) + if not self._connection_error: + self._connection_error = message.error self._resolve_pending_waits(message.error) if self._on_error: diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index 2972b28..b748eed 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -1266,3 +1266,91 @@ async def fake_send(message): with pytest.raises(WebRTCError, match="rate_limited"): await connection._send_initial_prompt_and_wait({"text": "test", "enhance": True}) + + +@pytest.mark.asyncio +async def test_server_error_survives_ws_disconnect_race(): + """Server error reaches the caller; receive-loop finally must not clobber it.""" + import json + import aiohttp + from decart.realtime.webrtc_connection import WebRTCConnection + + connection = WebRTCConnection() + + event, result = connection.register_image_set_wait() + + capacity_payload = json.dumps( + {"type": "error", "error": "Server at capacity. Please try again later."} + ) + + text_msg = MagicMock() + text_msg.type = aiohttp.WSMsgType.TEXT + text_msg.data = capacity_payload + + class FakeWS: + def __init__(self, messages): + self._messages = list(messages) + + def __aiter__(self): + return self + + async def __anext__(self): + if self._messages: + return self._messages.pop(0) + raise StopAsyncIteration + + connection._ws = FakeWS([text_msg]) # type: ignore[assignment] + + await connection._receive_messages() + + assert event.is_set() + assert result["success"] is False + assert result["error"] == "Server at capacity. Please try again later." + assert connection._connection_error == "Server at capacity. Please try again later." + + +@pytest.mark.asyncio +async def test_connect_raises_immediately_on_connection_error_subscribe_mode(): + """Subscribe mode: server error aborts connect() immediately, not at timeout.""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.errors import WebRTCError + + connection = WebRTCConnection() + + connection._send_passthrough_and_wait = AsyncMock() # type: ignore[assignment] + connection._setup_peer_connection = AsyncMock() # type: ignore[assignment] + connection._create_and_send_offer = AsyncMock() # type: ignore[assignment] + + async def _noop_receive(): + await asyncio.sleep(60) + + connection._receive_messages = _noop_receive # type: ignore[assignment] + + fake_ws = MagicMock() + fake_ws.closed = False + fake_ws.close = AsyncMock() + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session.ws_connect = AsyncMock(return_value=fake_ws) + + async def inject_error_soon(): + await asyncio.sleep(0.15) + connection._connection_error = "Server at capacity. Please try again later." + + injector = asyncio.create_task(inject_error_soon()) + + try: + with patch( + "decart.realtime.webrtc_connection.aiohttp.ClientSession", + return_value=mock_session, + ): + with pytest.raises(WebRTCError, match="Server at capacity"): + await connection.connect( + url="https://example.com/ws", + local_track=None, + timeout=10.0, + ) + finally: + injector.cancel()