Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None
self._local_track: Optional[MediaStreamTrack] = None
self._model_name: Optional[str] = None
self._connection_error: Optional[str] = None

async def connect(
self,
Expand All @@ -75,6 +76,7 @@ async def connect(
try:
self._local_track = local_track
self._model_name = model_name
self._connection_error = None

await self._set_state("connecting")

Expand Down Expand Up @@ -108,6 +110,8 @@ async def connect(
while asyncio.get_event_loop().time() < deadline:
if self._state in ("connected", "generating"):
return
if self._connection_error:
raise WebRTCError(self._connection_error)
Comment thread
AdirAmsalem marked this conversation as resolved.
await asyncio.sleep(0.1)

raise TimeoutError("Connection timeout")
Expand Down Expand Up @@ -285,7 +289,8 @@ async def _receive_messages(self) -> None:
if self._on_error:
self._on_error(e)
finally:
self._resolve_pending_waits("WebSocket disconnected")
final_error = self._connection_error or "WebSocket disconnected"
self._resolve_pending_waits(final_error)
await self._set_state("disconnected")

async def _handle_message(self, data: dict) -> None:
Expand Down Expand Up @@ -384,6 +389,8 @@ def _handle_error(self, message: ErrorMessage) -> None:
logger.error(f"Received error from server: {message.error}")
error = WebRTCError(message.error)

if not self._connection_error:
self._connection_error = message.error
self._resolve_pending_waits(message.error)

if self._on_error:
Expand Down
88 changes: 88 additions & 0 deletions tests/test_realtime_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,3 +1266,91 @@ async def fake_send(message):

with pytest.raises(WebRTCError, match="rate_limited"):
await connection._send_initial_prompt_and_wait({"text": "test", "enhance": True})


@pytest.mark.asyncio
async def test_server_error_survives_ws_disconnect_race():
"""Server error reaches the caller; receive-loop finally must not clobber it."""
import json
import aiohttp
from decart.realtime.webrtc_connection import WebRTCConnection

connection = WebRTCConnection()

event, result = connection.register_image_set_wait()

capacity_payload = json.dumps(
{"type": "error", "error": "Server at capacity. Please try again later."}
)

text_msg = MagicMock()
text_msg.type = aiohttp.WSMsgType.TEXT
text_msg.data = capacity_payload

class FakeWS:
def __init__(self, messages):
self._messages = list(messages)

def __aiter__(self):
return self

async def __anext__(self):
if self._messages:
return self._messages.pop(0)
raise StopAsyncIteration

connection._ws = FakeWS([text_msg]) # type: ignore[assignment]

await connection._receive_messages()

assert event.is_set()
assert result["success"] is False
assert result["error"] == "Server at capacity. Please try again later."
assert connection._connection_error == "Server at capacity. Please try again later."


@pytest.mark.asyncio
async def test_connect_raises_immediately_on_connection_error_subscribe_mode():
"""Subscribe mode: server error aborts connect() immediately, not at timeout."""
from decart.realtime.webrtc_connection import WebRTCConnection
from decart.errors import WebRTCError

connection = WebRTCConnection()

connection._send_passthrough_and_wait = AsyncMock() # type: ignore[assignment]
connection._setup_peer_connection = AsyncMock() # type: ignore[assignment]
connection._create_and_send_offer = AsyncMock() # type: ignore[assignment]

async def _noop_receive():
await asyncio.sleep(60)

connection._receive_messages = _noop_receive # type: ignore[assignment]

fake_ws = MagicMock()
fake_ws.closed = False
fake_ws.close = AsyncMock()

mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.ws_connect = AsyncMock(return_value=fake_ws)

async def inject_error_soon():
await asyncio.sleep(0.15)
connection._connection_error = "Server at capacity. Please try again later."

injector = asyncio.create_task(inject_error_soon())

try:
with patch(
"decart.realtime.webrtc_connection.aiohttp.ClientSession",
return_value=mock_session,
):
with pytest.raises(WebRTCError, match="Server at capacity"):
await connection.connect(
url="https://example.com/ws",
local_track=None,
timeout=10.0,
)
finally:
injector.cancel()
Loading