From 4d5f881d4fbbac0862c14b78fd8015a4a24162c2 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Wed, 6 May 2026 21:28:20 +0800 Subject: [PATCH 1/2] refactor: centralize interrupt request helper binding --- src/opencode_a2a/execution/stream_runtime.py | 32 +++----- src/opencode_a2a/interrupt_request_tracker.py | 79 ++++++++++++++++++ .../jsonrpc/handlers/interrupt_callbacks.py | 46 +++++------ tests/execution/test_metrics.py | 4 + tests/execution/test_multipart_input.py | 4 + ...st_streaming_output_contract_interrupts.py | 4 +- ...t_opencode_session_extension_interrupts.py | 80 +++++++++++++++++++ tests/support/helpers.py | 37 +-------- tests/support/streaming_output.py | 37 +-------- 9 files changed, 206 insertions(+), 117 deletions(-) create mode 100644 src/opencode_a2a/interrupt_request_tracker.py diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index a648d1b..950abc0 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -16,6 +16,7 @@ ) from ..a2a_utils import make_data_part +from ..interrupt_request_tracker import bind_interrupt_request_tracker from .event_helpers import _enqueue_artifact_update from .stream_events import ( BlockType, @@ -85,6 +86,7 @@ async def consume( pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list) backoff = 0.5 max_backoff = 5.0 + interrupt_requests = bind_interrupt_request_tracker(self._client) async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: for chunk in chunks: @@ -490,21 +492,15 @@ def _tool_chunks( if asked is not None: request_id = asked["request_id"] if stream_state.mark_interrupt_pending(request_id): - remember_request = getattr( - self._client, - "remember_interrupt_request", - None, + await interrupt_requests.remember_request( + request_id=request_id, + session_id=session_id, + interrupt_type=asked["interrupt_type"], + identity=identity, + task_id=task_id, + context_id=context_id, + details=asked["details"], ) - if callable(remember_request): - await remember_request( - request_id=request_id, - session_id=session_id, - interrupt_type=asked["interrupt_type"], - identity=identity, - task_id=task_id, - context_id=context_id, - details=asked["details"], - ) await _emit_interrupt_status( state=TaskState.TASK_STATE_INPUT_REQUIRED, request_id=request_id, @@ -518,13 +514,7 @@ def _tool_chunks( cleared_pending = stream_state.clear_interrupt_pending( resolved_request_id ) - discard_request = getattr( - self._client, - "discard_interrupt_request", - None, - ) - if callable(discard_request): - await discard_request(resolved_request_id) + await interrupt_requests.discard_request(resolved_request_id) if cleared_pending: await _emit_interrupt_status( state=TaskState.TASK_STATE_WORKING, diff --git a/src/opencode_a2a/interrupt_request_tracker.py b/src/opencode_a2a/interrupt_request_tracker.py new file mode 100644 index 0000000..74dbf31 --- /dev/null +++ b/src/opencode_a2a/interrupt_request_tracker.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +from .runtime_state import InterruptRequestBinding + + +@dataclass(frozen=True) +class InterruptRequestResolution: + status: Literal["active", "expired", "missing"] + binding: InterruptRequestBinding | None + + +class BoundInterruptRequestTracker: + def __init__(self, client: object) -> None: + remember_request = getattr(client, "remember_interrupt_request", None) + resolve_request = getattr(client, "resolve_interrupt_request", None) + resolve_session = getattr(client, "resolve_interrupt_session", None) + discard_request = getattr(client, "discard_interrupt_request", None) + + self._remember_request = remember_request if callable(remember_request) else None + self._resolve_request = resolve_request if callable(resolve_request) else None + self._resolve_session = resolve_session if callable(resolve_session) else None + self._discard_request = discard_request if callable(discard_request) else None + + async def remember_request( + self, + *, + request_id: str, + session_id: str, + interrupt_type: str, + identity: str | None = None, + credential_id: str | None = None, + task_id: str | None = None, + context_id: str | None = None, + details: dict[str, Any] | None = None, + ttl_seconds: float | None = None, + ) -> None: + if self._remember_request is None: + return + request_kwargs: dict[str, Any] = { + "request_id": request_id, + "session_id": session_id, + "interrupt_type": interrupt_type, + } + if identity is not None: + request_kwargs["identity"] = identity + if credential_id is not None: + request_kwargs["credential_id"] = credential_id + if task_id is not None: + request_kwargs["task_id"] = task_id + if context_id is not None: + request_kwargs["context_id"] = context_id + if details is not None: + request_kwargs["details"] = details + if ttl_seconds is not None: + request_kwargs["ttl_seconds"] = ttl_seconds + await self._remember_request(**request_kwargs) + + async def resolve_request(self, request_id: str) -> InterruptRequestResolution: + if self._resolve_request is not None: + status, binding = await self._resolve_request(request_id) + return InterruptRequestResolution(status=status, binding=binding) + if self._resolve_session is not None: + # Keep session-only clients working while newer call sites prefer bindings. + session_id = await self._resolve_session(request_id) + if isinstance(session_id, str) and session_id: + return InterruptRequestResolution(status="active", binding=None) + return InterruptRequestResolution(status="missing", binding=None) + + async def discard_request(self, request_id: str) -> None: + if self._discard_request is None: + return + await self._discard_request(request_id) + + +def bind_interrupt_request_tracker(client: object) -> BoundInterruptRequestTracker: + return BoundInterruptRequestTracker(client) diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py index 3b704db..7e9e159 100644 --- a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py @@ -8,6 +8,7 @@ from starlette.responses import Response from ...contracts.extensions import INTERRUPT_ERROR_BUSINESS_CODES +from ...interrupt_request_tracker import bind_interrupt_request_tracker from ...opencode_upstream_client import UpstreamConcurrencyLimitError from ..dispatch import ExtensionHandlerContext from ..error_responses import ( @@ -81,18 +82,21 @@ async def handle_interrupt_callback_request( expected_interrupt_type = ( "permission" if base_request.method == context.method_reply_permission else "question" ) - resolve_request = getattr(context.upstream_client, "resolve_interrupt_request", None) - if callable(resolve_request): - status, binding = await resolve_request(request_id) - if status != "active" or binding is None: - return context.error_response( - base_request.id, - interrupt_not_found_error( - ERR_INTERRUPT_EXPIRED if status == "expired" else ERR_INTERRUPT_NOT_FOUND, - request_id=request_id, - expired=status == "expired", - ), - ) + interrupt_requests = bind_interrupt_request_tracker(context.upstream_client) + resolution = await interrupt_requests.resolve_request(request_id) + if resolution.status != "active": + return context.error_response( + base_request.id, + interrupt_not_found_error( + ERR_INTERRUPT_EXPIRED + if resolution.status == "expired" + else ERR_INTERRUPT_NOT_FOUND, + request_id=request_id, + expired=resolution.status == "expired", + ), + ) + binding = resolution.binding + if binding is not None: if binding.interrupt_type != expected_interrupt_type: return context.error_response( base_request.id, @@ -129,16 +133,6 @@ async def handle_interrupt_callback_request( request_id=request_id, ), ) - else: - resolve_session = getattr(context.upstream_client, "resolve_interrupt_session", None) - if callable(resolve_session) and not await resolve_session(request_id): - return context.error_response( - base_request.id, - interrupt_not_found_error( - ERR_INTERRUPT_NOT_FOUND, - request_id=request_id, - ), - ) if base_request.method == context.method_reply_permission: allowed_fields = {"request_id", "reply", "message", "metadata"} @@ -183,9 +177,7 @@ async def handle_interrupt_callback_request( request_id, **routing_kwargs, ) - discard_request = getattr(context.upstream_client, "discard_interrupt_request", None) - if callable(discard_request): - await discard_request(request_id) + await interrupt_requests.discard_request(request_id) except ValueError as exc: return context.error_response( base_request.id, @@ -194,9 +186,7 @@ async def handle_interrupt_callback_request( except httpx.HTTPStatusError as exc: upstream_status = exc.response.status_code if upstream_status == 404: - discard_request = getattr(context.upstream_client, "discard_interrupt_request", None) - if callable(discard_request): - await discard_request(request_id) + await interrupt_requests.discard_request(request_id) return context.error_response( base_request.id, interrupt_not_found_error( diff --git a/tests/execution/test_metrics.py b/tests/execution/test_metrics.py index 458c9f2..0f96ffe 100644 --- a/tests/execution/test_metrics.py +++ b/tests/execution/test_metrics.py @@ -145,6 +145,10 @@ async def remember_interrupt_request( del interrupt_type, identity, task_id, context_id, details, ttl_seconds self._interrupt_requests[request_id] = session_id + async def resolve_interrupt_request(self, request_id: str): + del request_id + return "missing", None + async def discard_interrupt_request(self, request_id: str) -> None: self._interrupt_requests.pop(request_id, None) diff --git a/tests/execution/test_multipart_input.py b/tests/execution/test_multipart_input.py index c7a5db4..aef749b 100644 --- a/tests/execution/test_multipart_input.py +++ b/tests/execution/test_multipart_input.py @@ -67,6 +67,10 @@ async def stream_events(self, stop_event=None, *, directory: str | None = None): async def remember_interrupt_request(self, **_kwargs) -> None: return None + async def resolve_interrupt_request(self, request_id: str): + del request_id + return "missing", None + async def resolve_interrupt_session(self, request_id: str) -> str | None: del request_id return None diff --git a/tests/execution/test_streaming_output_contract_interrupts.py b/tests/execution/test_streaming_output_contract_interrupts.py index 45a56cf..85fcd20 100644 --- a/tests/execution/test_streaming_output_contract_interrupts.py +++ b/tests/execution/test_streaming_output_contract_interrupts.py @@ -83,7 +83,7 @@ async def test_streaming_emits_interrupt_status_for_permission_asked_event() -> assert "metadata" not in interrupt["details"] assert "tool" not in interrupt["details"] assert interrupt_statuses[0].status.state == TaskState.TASK_STATE_INPUT_REQUIRED - assert client._interrupt_requests["perm-req-1"]["details"] == { + assert client._interrupt_request_details["perm-req-1"] == { "permission": "read", "patterns": ["/data/project/.env.secret"], } @@ -132,7 +132,7 @@ async def test_streaming_emits_interrupt_status_for_question_asked_event() -> No ] assert "tool" not in interrupt["details"] assert interrupt_statuses[0].status.state == TaskState.TASK_STATE_INPUT_REQUIRED - assert client._interrupt_requests["q-req-1"]["details"] == { + assert client._interrupt_request_details["q-req-1"] == { "questions": [ { "header": "Confirm", diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index 914e797..705d194 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -413,6 +413,86 @@ async def permission_reply( assert dummy.permission_reply_calls == [] +@pytest.mark.asyncio +async def test_interrupt_callback_extension_accepts_legacy_session_only_client(monkeypatch): + import opencode_a2a.server.application as app_module + + class LegacyInterruptClient: + def __init__(self, _settings: Settings, **kwargs) -> None: + del kwargs + self.settings = _settings + self.directory = None + self.stream_timeout = None + self.permission_reply_calls: list[dict] = [] + self._active_requests = {"perm-legacy"} + + async def close(self) -> None: + return None + + async def resolve_interrupt_session(self, request_id: str) -> str | None: + if request_id in self._active_requests: + return "ses-legacy" + return None + + async def permission_reply( + self, + request_id: str, + *, + reply: str, + message: str | None = None, + directory: str | None = None, + workspace_id: str | None = None, + ) -> bool: + self.permission_reply_calls.append( + { + "request_id": request_id, + "reply": reply, + "message": message, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return True + + async def discard_interrupt_request(self, request_id: str) -> None: + self._active_requests.discard(request_id) + + dummy = LegacyInterruptClient( + make_settings(test_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings, **_kwargs: dummy) + app = app_module.create_app( + make_settings(test_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = _extension_headers({"Authorization": "Bearer t-1"}) + resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 161, + "method": "a2a.interrupt.permission.reply", + "params": {"request_id": "perm-legacy", "reply": "once"}, + }, + ) + payload = resp.json() + assert payload.get("error") is None + assert payload["result"] == {"ok": True, "request_id": "perm-legacy"} + assert dummy.permission_reply_calls == [ + { + "request_id": "perm-legacy", + "reply": "once", + "message": None, + "directory": None, + "workspace_id": None, + } + ] + assert "perm-legacy" not in dummy._active_requests + + @pytest.mark.asyncio async def test_interrupt_callback_extension_rejects_interrupt_type_mismatch(monkeypatch): import opencode_a2a.server.application as app_module diff --git a/tests/support/helpers.py b/tests/support/helpers.py index c79f883..28a62c4 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -17,6 +17,7 @@ from opencode_a2a.opencode_upstream_client import OpencodeMessage from opencode_a2a.server.context_helpers import normalize_server_call_context from tests.support import settings as test_settings +from tests.support.interrupt_clients import InterruptRequestClientMixin def make_basic_auth_header(username: str, password: str) -> dict[str, str]: @@ -160,7 +161,7 @@ def make_request_context_with_parts( ) -class DummyChatOpencodeUpstreamClient: +class DummyChatOpencodeUpstreamClient(InterruptRequestClientMixin): def __init__( self, settings: Settings | None = None, @@ -178,6 +179,8 @@ def __init__( self.settings = settings or test_settings.make_settings( opencode_base_url="http://localhost" ) + self._interrupt_requests: dict[str, dict[str, str | None]] = {} + self._interrupt_request_details: dict[str, dict[str, Any] | None] = {} async def close(self) -> None: return None @@ -226,35 +229,3 @@ async def stream_events( # noqa: ANN001 del stop_event, directory, workspace_id for _ in (): yield {} - - async def remember_interrupt_request( - self, - *, - request_id: str, - session_id: str, - interrupt_type: str | None = None, - identity: str | None = None, - credential_id: str | None = None, - task_id: str | None = None, - context_id: str | None = None, - details: dict[str, Any] | None = None, - ttl_seconds: float | None = None, - ) -> None: - del ( - request_id, - session_id, - interrupt_type, - identity, - credential_id, - task_id, - context_id, - details, - ttl_seconds, - ) - - async def resolve_interrupt_session(self, request_id: str) -> str | None: - del request_id - return None - - async def discard_interrupt_request(self, request_id: str) -> None: - del request_id diff --git a/tests/support/streaming_output.py b/tests/support/streaming_output.py index 75cd8b5..a84df73 100644 --- a/tests/support/streaming_output.py +++ b/tests/support/streaming_output.py @@ -10,10 +10,11 @@ from tests.support.helpers import ( DummyEventQueue, ) +from tests.support.interrupt_clients import InterruptRequestClientMixin from tests.support.settings import make_settings -class DummyStreamingClient: +class DummyStreamingClient(InterruptRequestClientMixin): def __init__( self, *, @@ -36,8 +37,8 @@ def __init__( self.max_in_flight_send = 0 self.stream_timeout = None self.directory = None - self._interrupt_sessions: dict[str, str] = {} - self._interrupt_requests: dict[str, dict] = {} + self._interrupt_requests: dict[str, dict[str, str | None]] = {} + self._interrupt_request_details: dict[str, dict | None] = {} self.settings = make_settings( test_bearer_token="test", opencode_base_url="http://localhost", @@ -90,36 +91,6 @@ async def stream_events(self, stop_event=None, *, directory: str | None = None): ): yield {"type": "session.idle", "properties": {"sessionID": "ses-1"}} - async def remember_interrupt_request( - self, - *, - request_id: str, - session_id: str, - interrupt_type: str | None = None, - identity: str | None = None, - task_id: str | None = None, - context_id: str | None = None, - details: dict | None = None, - ttl_seconds: float | None = None, - ) -> None: - del ttl_seconds - self._interrupt_sessions[request_id] = session_id - self._interrupt_requests[request_id] = { - "session_id": session_id, - "interrupt_type": interrupt_type, - "identity": identity, - "task_id": task_id, - "context_id": context_id, - "details": details, - } - - async def resolve_interrupt_session(self, request_id: str) -> str | None: - return self._interrupt_sessions.get(request_id) - - async def discard_interrupt_request(self, request_id: str) -> None: - self._interrupt_sessions.pop(request_id, None) - self._interrupt_requests.pop(request_id, None) - def _event( *, From 8b32cd5c6bdcd737a48cf5341953b37629404203 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Wed, 6 May 2026 21:56:06 +0800 Subject: [PATCH 2/2] refactor: drop interrupt tracker forwarding helper --- src/opencode_a2a/execution/stream_runtime.py | 4 ++-- src/opencode_a2a/interrupt_request_tracker.py | 4 ---- src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py | 4 ++-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 950abc0..54e7605 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -16,7 +16,7 @@ ) from ..a2a_utils import make_data_part -from ..interrupt_request_tracker import bind_interrupt_request_tracker +from ..interrupt_request_tracker import BoundInterruptRequestTracker from .event_helpers import _enqueue_artifact_update from .stream_events import ( BlockType, @@ -86,7 +86,7 @@ async def consume( pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list) backoff = 0.5 max_backoff = 5.0 - interrupt_requests = bind_interrupt_request_tracker(self._client) + interrupt_requests = BoundInterruptRequestTracker(self._client) async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: for chunk in chunks: diff --git a/src/opencode_a2a/interrupt_request_tracker.py b/src/opencode_a2a/interrupt_request_tracker.py index 74dbf31..cbb04b7 100644 --- a/src/opencode_a2a/interrupt_request_tracker.py +++ b/src/opencode_a2a/interrupt_request_tracker.py @@ -73,7 +73,3 @@ async def discard_request(self, request_id: str) -> None: if self._discard_request is None: return await self._discard_request(request_id) - - -def bind_interrupt_request_tracker(client: object) -> BoundInterruptRequestTracker: - return BoundInterruptRequestTracker(client) diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py index 7e9e159..04b6f03 100644 --- a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py @@ -8,7 +8,7 @@ from starlette.responses import Response from ...contracts.extensions import INTERRUPT_ERROR_BUSINESS_CODES -from ...interrupt_request_tracker import bind_interrupt_request_tracker +from ...interrupt_request_tracker import BoundInterruptRequestTracker from ...opencode_upstream_client import UpstreamConcurrencyLimitError from ..dispatch import ExtensionHandlerContext from ..error_responses import ( @@ -82,7 +82,7 @@ async def handle_interrupt_callback_request( expected_interrupt_type = ( "permission" if base_request.method == context.method_reply_permission else "question" ) - interrupt_requests = bind_interrupt_request_tracker(context.upstream_client) + interrupt_requests = BoundInterruptRequestTracker(context.upstream_client) resolution = await interrupt_requests.resolve_request(request_id) if resolution.status != "active": return context.error_response(