Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(
self._local_track: Optional[MediaStreamTrack] = None
self._model_name: Optional[str] = None
self._connection_error: Optional[str] = None
# Per-connect() dedup: _handle_error and connect()'s except branches both
# may see the same error; whichever fires first flips this to True and the
# other skips. Reset at the top of every connect() call.
self._on_error_fired: bool = False

async def connect(
self,
Expand All @@ -77,6 +81,7 @@ async def connect(
self._local_track = local_track
self._model_name = model_name
self._connection_error = None
self._on_error_fired = False

await self._set_state("connecting")

Expand Down Expand Up @@ -116,10 +121,18 @@ async def connect(

raise TimeoutError("Connection timeout")

except WebRTCError as e:
logger.error(f"Connection failed: {e}")
await self._set_state("disconnected")
if self._on_error and not self._on_error_fired:
self._on_error_fired = True
self._on_error(e)
raise
Comment thread
cursor[bot] marked this conversation as resolved.
except Exception as e:
logger.error(f"Connection failed: {e}")
await self._set_state("disconnected")
if self._on_error:
if self._on_error and not self._on_error_fired:
self._on_error_fired = True
self._on_error(e)
raise WebRTCError(str(e), cause=e)

Expand Down Expand Up @@ -394,6 +407,7 @@ def _handle_error(self, message: ErrorMessage) -> None:
self._resolve_pending_waits(message.error)

if self._on_error:
self._on_error_fired = True
self._on_error(error)

def register_image_set_wait(self) -> tuple[asyncio.Event, dict]:
Expand Down
108 changes: 108 additions & 0 deletions tests/test_realtime_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,3 +1354,111 @@ async def inject_error_soon():
)
finally:
injector.cancel()


@pytest.mark.asyncio
async def test_connect_does_not_double_wrap_webrtc_error():
"""WebRTCError raised inside connect() re-raises as-is — no nested cause, no duplicate on_error."""
from decart.realtime.webrtc_connection import WebRTCConnection
from decart.errors import WebRTCError

errors: list[Exception] = []
connection = WebRTCConnection(on_error=lambda e: errors.append(e))

connection._setup_peer_connection = AsyncMock() # type: ignore[assignment]
connection._create_and_send_offer = AsyncMock() # type: ignore[assignment]

async def _noop_receive():
await asyncio.sleep(60)

connection._receive_messages = _noop_receive # type: ignore[assignment]

fake_ws = MagicMock()
fake_ws.closed = False
fake_ws.close = AsyncMock()

mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.ws_connect = AsyncMock(return_value=fake_ws)

async def inject_error_soon():
await asyncio.sleep(0.15)
# Simulate _handle_error having run: it sets _connection_error, fires
# on_error, and marks _on_error_fired so connect() doesn't double-fire.
connection._connection_error = "Server at capacity. Please try again later."
errors.append(WebRTCError("Server at capacity. Please try again later."))
connection._on_error_fired = True

injector = asyncio.create_task(inject_error_soon())

try:
with patch(
"decart.realtime.webrtc_connection.aiohttp.ClientSession",
return_value=mock_session,
):
with pytest.raises(WebRTCError) as exc_info:
await connection.connect(
url="https://example.com/ws",
local_track=None,
timeout=10.0,
)
finally:
injector.cancel()

assert exc_info.value.message == "Server at capacity. Please try again later."
assert not isinstance(exc_info.value.cause, WebRTCError)
assert len(errors) == 1, (
"connect()'s WebRTCError handler must not fire on_error again when _handle_error "
f"already did; got {errors!r}"
)


@pytest.mark.asyncio
async def test_connect_direct_raise_fires_on_error_once():
"""Direct-raise WebRTCError paths (e.g. ack timeouts) must fire on_error exactly once."""
from decart.realtime.webrtc_connection import WebRTCConnection
from decart.errors import WebRTCError

errors: list[Exception] = []
connection = WebRTCConnection(on_error=lambda e: errors.append(e))

async def _raise_ack_timeout(prompt, timeout=15.0):
raise WebRTCError("Initial prompt acknowledgment timed out")

connection._send_initial_prompt_and_wait = _raise_ack_timeout # type: ignore[assignment]
connection._setup_peer_connection = AsyncMock() # type: ignore[assignment]
connection._create_and_send_offer = AsyncMock() # type: ignore[assignment]

async def _noop_receive():
await asyncio.sleep(60)

connection._receive_messages = _noop_receive # type: ignore[assignment]

fake_ws = MagicMock()
fake_ws.closed = False
fake_ws.close = AsyncMock()

mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.ws_connect = AsyncMock(return_value=fake_ws)

with patch(
"decart.realtime.webrtc_connection.aiohttp.ClientSession",
return_value=mock_session,
):
with pytest.raises(WebRTCError) as exc_info:
await connection.connect(
url="https://example.com/ws",
local_track=None,
timeout=10.0,
initial_prompt={"text": "hello", "enhance": True},
)

assert exc_info.value.message == "Initial prompt acknowledgment timed out"
assert not isinstance(exc_info.value.cause, WebRTCError)
assert (
len(errors) == 1
), f"on_error should fire exactly once for direct-raise paths; got {errors!r}"
assert errors[0] is exc_info.value
Loading