diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index ef5e0c5..12f53ed 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -61,6 +61,7 @@ def __init__( self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None self._local_track: Optional[MediaStreamTrack] = None self._model_name: Optional[str] = None + self._connection_error: Optional[str] = None async def connect( self, @@ -75,6 +76,7 @@ async def connect( try: self._local_track = local_track self._model_name = model_name + self._connection_error = None await self._set_state("connecting") @@ -108,6 +110,8 @@ async def connect( while asyncio.get_event_loop().time() < deadline: if self._state in ("connected", "generating"): return + if self._connection_error: + raise WebRTCError(self._connection_error) await asyncio.sleep(0.1) raise TimeoutError("Connection timeout") @@ -285,7 +289,8 @@ async def _receive_messages(self) -> None: if self._on_error: self._on_error(e) finally: - self._resolve_pending_waits("WebSocket disconnected") + final_error = self._connection_error or "WebSocket disconnected" + self._resolve_pending_waits(final_error) await self._set_state("disconnected") async def _handle_message(self, data: dict) -> None: @@ -384,6 +389,8 @@ def _handle_error(self, message: ErrorMessage) -> None: logger.error(f"Received error from server: {message.error}") error = WebRTCError(message.error) + if not self._connection_error: + self._connection_error = message.error self._resolve_pending_waits(message.error) if self._on_error: diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index 2972b28..b748eed 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -1266,3 +1266,91 @@ async def fake_send(message): with pytest.raises(WebRTCError, match="rate_limited"): await connection._send_initial_prompt_and_wait({"text": "test", "enhance": True}) + + +@pytest.mark.asyncio +async def test_server_error_survives_ws_disconnect_race(): + """Server error reaches the caller; receive-loop finally must not clobber it.""" + import json + import aiohttp + from decart.realtime.webrtc_connection import WebRTCConnection + + connection = WebRTCConnection() + + event, result = connection.register_image_set_wait() + + capacity_payload = json.dumps( + {"type": "error", "error": "Server at capacity. Please try again later."} + ) + + text_msg = MagicMock() + text_msg.type = aiohttp.WSMsgType.TEXT + text_msg.data = capacity_payload + + class FakeWS: + def __init__(self, messages): + self._messages = list(messages) + + def __aiter__(self): + return self + + async def __anext__(self): + if self._messages: + return self._messages.pop(0) + raise StopAsyncIteration + + connection._ws = FakeWS([text_msg]) # type: ignore[assignment] + + await connection._receive_messages() + + assert event.is_set() + assert result["success"] is False + assert result["error"] == "Server at capacity. Please try again later." + assert connection._connection_error == "Server at capacity. Please try again later." + + +@pytest.mark.asyncio +async def test_connect_raises_immediately_on_connection_error_subscribe_mode(): + """Subscribe mode: server error aborts connect() immediately, not at timeout.""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.errors import WebRTCError + + connection = WebRTCConnection() + + connection._send_passthrough_and_wait = AsyncMock() # type: ignore[assignment] + connection._setup_peer_connection = AsyncMock() # type: ignore[assignment] + connection._create_and_send_offer = AsyncMock() # type: ignore[assignment] + + async def _noop_receive(): + await asyncio.sleep(60) + + connection._receive_messages = _noop_receive # type: ignore[assignment] + + fake_ws = MagicMock() + fake_ws.closed = False + fake_ws.close = AsyncMock() + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session.ws_connect = AsyncMock(return_value=fake_ws) + + async def inject_error_soon(): + await asyncio.sleep(0.15) + connection._connection_error = "Server at capacity. Please try again later." + + injector = asyncio.create_task(inject_error_soon()) + + try: + with patch( + "decart.realtime.webrtc_connection.aiohttp.ClientSession", + return_value=mock_session, + ): + with pytest.raises(WebRTCError, match="Server at capacity"): + await connection.connect( + url="https://example.com/ws", + local_track=None, + timeout=10.0, + ) + finally: + injector.cancel()