diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index 12f53ed..a51e8ed 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -62,6 +62,10 @@ def __init__( self._local_track: Optional[MediaStreamTrack] = None self._model_name: Optional[str] = None self._connection_error: Optional[str] = None + # Per-connect() dedup: _handle_error and connect()'s except branches both + # may see the same error; whichever fires first flips this to True and the + # other skips. Reset at the top of every connect() call. + self._on_error_fired: bool = False async def connect( self, @@ -77,6 +81,7 @@ async def connect( self._local_track = local_track self._model_name = model_name self._connection_error = None + self._on_error_fired = False await self._set_state("connecting") @@ -116,10 +121,18 @@ async def connect( raise TimeoutError("Connection timeout") + except WebRTCError as e: + logger.error(f"Connection failed: {e}") + await self._set_state("disconnected") + if self._on_error and not self._on_error_fired: + self._on_error_fired = True + self._on_error(e) + raise except Exception as e: logger.error(f"Connection failed: {e}") await self._set_state("disconnected") - if self._on_error: + if self._on_error and not self._on_error_fired: + self._on_error_fired = True self._on_error(e) raise WebRTCError(str(e), cause=e) @@ -394,6 +407,7 @@ def _handle_error(self, message: ErrorMessage) -> None: self._resolve_pending_waits(message.error) if self._on_error: + self._on_error_fired = True self._on_error(error) def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index b748eed..d13e82d 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -1354,3 +1354,111 @@ async def inject_error_soon(): ) finally: injector.cancel() + + +@pytest.mark.asyncio +async def test_connect_does_not_double_wrap_webrtc_error(): + """WebRTCError raised inside connect() re-raises as-is — no nested cause, no duplicate on_error.""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.errors import WebRTCError + + errors: list[Exception] = [] + connection = WebRTCConnection(on_error=lambda e: errors.append(e)) + + connection._setup_peer_connection = AsyncMock() # type: ignore[assignment] + connection._create_and_send_offer = AsyncMock() # type: ignore[assignment] + + async def _noop_receive(): + await asyncio.sleep(60) + + connection._receive_messages = _noop_receive # type: ignore[assignment] + + fake_ws = MagicMock() + fake_ws.closed = False + fake_ws.close = AsyncMock() + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session.ws_connect = AsyncMock(return_value=fake_ws) + + async def inject_error_soon(): + await asyncio.sleep(0.15) + # Simulate _handle_error having run: it sets _connection_error, fires + # on_error, and marks _on_error_fired so connect() doesn't double-fire. + connection._connection_error = "Server at capacity. Please try again later." + errors.append(WebRTCError("Server at capacity. Please try again later.")) + connection._on_error_fired = True + + injector = asyncio.create_task(inject_error_soon()) + + try: + with patch( + "decart.realtime.webrtc_connection.aiohttp.ClientSession", + return_value=mock_session, + ): + with pytest.raises(WebRTCError) as exc_info: + await connection.connect( + url="https://example.com/ws", + local_track=None, + timeout=10.0, + ) + finally: + injector.cancel() + + assert exc_info.value.message == "Server at capacity. Please try again later." + assert not isinstance(exc_info.value.cause, WebRTCError) + assert len(errors) == 1, ( + "connect()'s WebRTCError handler must not fire on_error again when _handle_error " + f"already did; got {errors!r}" + ) + + +@pytest.mark.asyncio +async def test_connect_direct_raise_fires_on_error_once(): + """Direct-raise WebRTCError paths (e.g. ack timeouts) must fire on_error exactly once.""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.errors import WebRTCError + + errors: list[Exception] = [] + connection = WebRTCConnection(on_error=lambda e: errors.append(e)) + + async def _raise_ack_timeout(prompt, timeout=15.0): + raise WebRTCError("Initial prompt acknowledgment timed out") + + connection._send_initial_prompt_and_wait = _raise_ack_timeout # type: ignore[assignment] + connection._setup_peer_connection = AsyncMock() # type: ignore[assignment] + connection._create_and_send_offer = AsyncMock() # type: ignore[assignment] + + async def _noop_receive(): + await asyncio.sleep(60) + + connection._receive_messages = _noop_receive # type: ignore[assignment] + + fake_ws = MagicMock() + fake_ws.closed = False + fake_ws.close = AsyncMock() + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session.ws_connect = AsyncMock(return_value=fake_ws) + + with patch( + "decart.realtime.webrtc_connection.aiohttp.ClientSession", + return_value=mock_session, + ): + with pytest.raises(WebRTCError) as exc_info: + await connection.connect( + url="https://example.com/ws", + local_track=None, + timeout=10.0, + initial_prompt={"text": "hello", "enhance": True}, + ) + + assert exc_info.value.message == "Initial prompt acknowledgment timed out" + assert not isinstance(exc_info.value.cause, WebRTCError) + assert ( + len(errors) == 1 + ), f"on_error should fire exactly once for direct-raise paths; got {errors!r}" + assert errors[0] is exc_info.value