Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions src/opencode_a2a/execution/stream_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from ..a2a_utils import make_data_part
from ..interrupt_request_tracker import BoundInterruptRequestTracker
from .event_helpers import _enqueue_artifact_update
from .stream_events import (
BlockType,
Expand Down Expand Up @@ -85,6 +86,7 @@ async def consume(
pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list)
backoff = 0.5
max_backoff = 5.0
interrupt_requests = BoundInterruptRequestTracker(self._client)

async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None:
for chunk in chunks:
Expand Down Expand Up @@ -490,21 +492,15 @@ def _tool_chunks(
if asked is not None:
request_id = asked["request_id"]
if stream_state.mark_interrupt_pending(request_id):
remember_request = getattr(
self._client,
"remember_interrupt_request",
None,
await interrupt_requests.remember_request(
request_id=request_id,
session_id=session_id,
interrupt_type=asked["interrupt_type"],
identity=identity,
task_id=task_id,
context_id=context_id,
details=asked["details"],
)
if callable(remember_request):
await remember_request(
request_id=request_id,
session_id=session_id,
interrupt_type=asked["interrupt_type"],
identity=identity,
task_id=task_id,
context_id=context_id,
details=asked["details"],
)
await _emit_interrupt_status(
state=TaskState.TASK_STATE_INPUT_REQUIRED,
request_id=request_id,
Expand All @@ -518,13 +514,7 @@ def _tool_chunks(
cleared_pending = stream_state.clear_interrupt_pending(
resolved_request_id
)
discard_request = getattr(
self._client,
"discard_interrupt_request",
None,
)
if callable(discard_request):
await discard_request(resolved_request_id)
await interrupt_requests.discard_request(resolved_request_id)
if cleared_pending:
await _emit_interrupt_status(
state=TaskState.TASK_STATE_WORKING,
Expand Down
75 changes: 75 additions & 0 deletions src/opencode_a2a/interrupt_request_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal

from .runtime_state import InterruptRequestBinding


@dataclass(frozen=True)
class InterruptRequestResolution:
status: Literal["active", "expired", "missing"]
binding: InterruptRequestBinding | None


class BoundInterruptRequestTracker:
def __init__(self, client: object) -> None:
remember_request = getattr(client, "remember_interrupt_request", None)
resolve_request = getattr(client, "resolve_interrupt_request", None)
resolve_session = getattr(client, "resolve_interrupt_session", None)
discard_request = getattr(client, "discard_interrupt_request", None)

self._remember_request = remember_request if callable(remember_request) else None
self._resolve_request = resolve_request if callable(resolve_request) else None
self._resolve_session = resolve_session if callable(resolve_session) else None
self._discard_request = discard_request if callable(discard_request) else None

async def remember_request(
self,
*,
request_id: str,
session_id: str,
interrupt_type: str,
identity: str | None = None,
credential_id: str | None = None,
task_id: str | None = None,
context_id: str | None = None,
details: dict[str, Any] | None = None,
ttl_seconds: float | None = None,
) -> None:
if self._remember_request is None:
return
request_kwargs: dict[str, Any] = {
"request_id": request_id,
"session_id": session_id,
"interrupt_type": interrupt_type,
}
if identity is not None:
request_kwargs["identity"] = identity
if credential_id is not None:
request_kwargs["credential_id"] = credential_id
if task_id is not None:
request_kwargs["task_id"] = task_id
if context_id is not None:
request_kwargs["context_id"] = context_id
if details is not None:
request_kwargs["details"] = details
if ttl_seconds is not None:
request_kwargs["ttl_seconds"] = ttl_seconds
await self._remember_request(**request_kwargs)

async def resolve_request(self, request_id: str) -> InterruptRequestResolution:
if self._resolve_request is not None:
status, binding = await self._resolve_request(request_id)
return InterruptRequestResolution(status=status, binding=binding)
if self._resolve_session is not None:
# Keep session-only clients working while newer call sites prefer bindings.
session_id = await self._resolve_session(request_id)
if isinstance(session_id, str) and session_id:
return InterruptRequestResolution(status="active", binding=None)
return InterruptRequestResolution(status="missing", binding=None)

async def discard_request(self, request_id: str) -> None:
if self._discard_request is None:
return
await self._discard_request(request_id)
46 changes: 18 additions & 28 deletions src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette.responses import Response

from ...contracts.extensions import INTERRUPT_ERROR_BUSINESS_CODES
from ...interrupt_request_tracker import BoundInterruptRequestTracker
from ...opencode_upstream_client import UpstreamConcurrencyLimitError
from ..dispatch import ExtensionHandlerContext
from ..error_responses import (
Expand Down Expand Up @@ -81,18 +82,21 @@ async def handle_interrupt_callback_request(
expected_interrupt_type = (
"permission" if base_request.method == context.method_reply_permission else "question"
)
resolve_request = getattr(context.upstream_client, "resolve_interrupt_request", None)
if callable(resolve_request):
status, binding = await resolve_request(request_id)
if status != "active" or binding is None:
return context.error_response(
base_request.id,
interrupt_not_found_error(
ERR_INTERRUPT_EXPIRED if status == "expired" else ERR_INTERRUPT_NOT_FOUND,
request_id=request_id,
expired=status == "expired",
),
)
interrupt_requests = BoundInterruptRequestTracker(context.upstream_client)
resolution = await interrupt_requests.resolve_request(request_id)
if resolution.status != "active":
return context.error_response(
base_request.id,
interrupt_not_found_error(
ERR_INTERRUPT_EXPIRED
if resolution.status == "expired"
else ERR_INTERRUPT_NOT_FOUND,
request_id=request_id,
expired=resolution.status == "expired",
),
)
binding = resolution.binding
if binding is not None:
if binding.interrupt_type != expected_interrupt_type:
return context.error_response(
base_request.id,
Expand Down Expand Up @@ -129,16 +133,6 @@ async def handle_interrupt_callback_request(
request_id=request_id,
),
)
else:
resolve_session = getattr(context.upstream_client, "resolve_interrupt_session", None)
if callable(resolve_session) and not await resolve_session(request_id):
return context.error_response(
base_request.id,
interrupt_not_found_error(
ERR_INTERRUPT_NOT_FOUND,
request_id=request_id,
),
)

if base_request.method == context.method_reply_permission:
allowed_fields = {"request_id", "reply", "message", "metadata"}
Expand Down Expand Up @@ -183,9 +177,7 @@ async def handle_interrupt_callback_request(
request_id,
**routing_kwargs,
)
discard_request = getattr(context.upstream_client, "discard_interrupt_request", None)
if callable(discard_request):
await discard_request(request_id)
await interrupt_requests.discard_request(request_id)
except ValueError as exc:
return context.error_response(
base_request.id,
Expand All @@ -194,9 +186,7 @@ async def handle_interrupt_callback_request(
except httpx.HTTPStatusError as exc:
upstream_status = exc.response.status_code
if upstream_status == 404:
discard_request = getattr(context.upstream_client, "discard_interrupt_request", None)
if callable(discard_request):
await discard_request(request_id)
await interrupt_requests.discard_request(request_id)
return context.error_response(
base_request.id,
interrupt_not_found_error(
Expand Down
4 changes: 4 additions & 0 deletions tests/execution/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ async def remember_interrupt_request(
del interrupt_type, identity, task_id, context_id, details, ttl_seconds
self._interrupt_requests[request_id] = session_id

async def resolve_interrupt_request(self, request_id: str):
del request_id
return "missing", None

async def discard_interrupt_request(self, request_id: str) -> None:
self._interrupt_requests.pop(request_id, None)

Expand Down
4 changes: 4 additions & 0 deletions tests/execution/test_multipart_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ async def stream_events(self, stop_event=None, *, directory: str | None = None):
async def remember_interrupt_request(self, **_kwargs) -> None:
return None

async def resolve_interrupt_request(self, request_id: str):
del request_id
return "missing", None

async def resolve_interrupt_session(self, request_id: str) -> str | None:
del request_id
return None
Expand Down
4 changes: 2 additions & 2 deletions tests/execution/test_streaming_output_contract_interrupts.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def test_streaming_emits_interrupt_status_for_permission_asked_event() ->
assert "metadata" not in interrupt["details"]
assert "tool" not in interrupt["details"]
assert interrupt_statuses[0].status.state == TaskState.TASK_STATE_INPUT_REQUIRED
assert client._interrupt_requests["perm-req-1"]["details"] == {
assert client._interrupt_request_details["perm-req-1"] == {
"permission": "read",
"patterns": ["/data/project/.env.secret"],
}
Expand Down Expand Up @@ -132,7 +132,7 @@ async def test_streaming_emits_interrupt_status_for_question_asked_event() -> No
]
assert "tool" not in interrupt["details"]
assert interrupt_statuses[0].status.state == TaskState.TASK_STATE_INPUT_REQUIRED
assert client._interrupt_requests["q-req-1"]["details"] == {
assert client._interrupt_request_details["q-req-1"] == {
"questions": [
{
"header": "Confirm",
Expand Down
80 changes: 80 additions & 0 deletions tests/jsonrpc/test_opencode_session_extension_interrupts.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,86 @@ async def permission_reply(
assert dummy.permission_reply_calls == []


@pytest.mark.asyncio
async def test_interrupt_callback_extension_accepts_legacy_session_only_client(monkeypatch):
import opencode_a2a.server.application as app_module

class LegacyInterruptClient:
def __init__(self, _settings: Settings, **kwargs) -> None:
del kwargs
self.settings = _settings
self.directory = None
self.stream_timeout = None
self.permission_reply_calls: list[dict] = []
self._active_requests = {"perm-legacy"}

async def close(self) -> None:
return None

async def resolve_interrupt_session(self, request_id: str) -> str | None:
if request_id in self._active_requests:
return "ses-legacy"
return None

async def permission_reply(
self,
request_id: str,
*,
reply: str,
message: str | None = None,
directory: str | None = None,
workspace_id: str | None = None,
) -> bool:
self.permission_reply_calls.append(
{
"request_id": request_id,
"reply": reply,
"message": message,
"directory": directory,
"workspace_id": workspace_id,
}
)
return True

async def discard_interrupt_request(self, request_id: str) -> None:
self._active_requests.discard(request_id)

dummy = LegacyInterruptClient(
make_settings(test_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS)
)
monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings, **_kwargs: dummy)
app = app_module.create_app(
make_settings(test_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS)
)

transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
headers = _extension_headers({"Authorization": "Bearer t-1"})
resp = await client.post(
"/",
headers=headers,
json={
"jsonrpc": "2.0",
"id": 161,
"method": "a2a.interrupt.permission.reply",
"params": {"request_id": "perm-legacy", "reply": "once"},
},
)
payload = resp.json()
assert payload.get("error") is None
assert payload["result"] == {"ok": True, "request_id": "perm-legacy"}
assert dummy.permission_reply_calls == [
{
"request_id": "perm-legacy",
"reply": "once",
"message": None,
"directory": None,
"workspace_id": None,
}
]
assert "perm-legacy" not in dummy._active_requests


@pytest.mark.asyncio
async def test_interrupt_callback_extension_rejects_interrupt_type_mismatch(monkeypatch):
import opencode_a2a.server.application as app_module
Expand Down
Loading
Loading