From 45902c5892776868aedb00ae66309b57ddbfc980 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Wed, 6 May 2026 20:27:39 +0800 Subject: [PATCH 1/2] test: cover parsing edges and split test upstream fakes --- src/opencode_a2a/client/error_mapping.py | 5 +- tests/client/test_error_mapping.py | 171 ++++- .../test_extension_contract_consistency.py | 4 +- tests/jsonrpc/test_application_dispatch.py | 2 +- tests/jsonrpc/test_dispatch_registry.py | 2 +- ...est_opencode_session_extension_commands.py | 6 +- ...t_opencode_session_extension_interrupts.py | 6 +- ...st_opencode_session_extension_lifecycle.py | 6 +- ...opencode_session_extension_prompt_async.py | 6 +- ...test_opencode_session_extension_queries.py | 6 +- ...st_opencode_workspace_control_extension.py | 6 +- tests/server/test_cli.py | 178 +++++ tests/server/test_transport_contract.py | 2 +- tests/support/helpers.py | 631 +----------------- tests/support/interrupt_clients.py | 141 ++++ tests/support/session_query_client.py | 450 +++++++++++++ tests/support/workspace_control_client.py | 65 ++ tests/test_parsing.py | 126 ++++ 18 files changed, 1155 insertions(+), 658 deletions(-) create mode 100644 tests/support/interrupt_clients.py create mode 100644 tests/support/session_query_client.py create mode 100644 tests/support/workspace_control_client.py create mode 100644 tests/test_parsing.py diff --git a/src/opencode_a2a/client/error_mapping.py b/src/opencode_a2a/client/error_mapping.py index 52bac2c..53c402e 100644 --- a/src/opencode_a2a/client/error_mapping.py +++ b/src/opencode_a2a/client/error_mapping.py @@ -238,10 +238,7 @@ def map_agent_card_error( ) -> A2AClientError: if isinstance(exc, AgentCardResolutionError): if exc.status_code is not None: - return _attach_http_status( - map_client_error("agent-card/fetch", SDKClientError(str(exc))), - exc.status_code, - ) + return map_http_error("agent-card/fetch", exc) return A2APeerProtocolError( "Remote A2A peer returned an invalid agent card payload", error_code="invalid_agent_card", diff --git a/tests/client/test_error_mapping.py b/tests/client/test_error_mapping.py index d2039c0..20a94ca 100644 --- a/tests/client/test_error_mapping.py +++ b/tests/client/test_error_mapping.py @@ -1,9 +1,26 @@ from __future__ import annotations import httpx +import pytest +from a2a.client.errors import ( + A2AClientError as SDKClientError, +) +from a2a.client.errors import ( + A2AClientTimeoutError, + AgentCardResolutionError, +) +from a2a.utils.errors import ( + A2AError, + InvalidParamsError, + MethodNotFoundError, + TaskNotFoundError, + VersionNotSupportedError, +) from opencode_a2a.client.error_mapping import ( + map_a2a_error, map_agent_card_error, + map_client_error, map_http_error, map_jsonrpc_error, map_operation_error, @@ -25,10 +42,129 @@ ) +@pytest.mark.parametrize( + ("exc", "expected_type", "error_code", "data"), + [ + pytest.param( + TaskNotFoundError("missing"), + A2AUnsupportedOperationError, + "task_not_found", + None, + id="task-not-found", + ), + pytest.param( + MethodNotFoundError("unsupported"), + A2AUnsupportedOperationError, + "method_not_supported", + None, + id="method-not-found", + ), + pytest.param( + InvalidParamsError("bad", data={"field": "limit"}), + A2APeerProtocolError, + "invalid_params", + {"field": "limit"}, + id="invalid-params", + ), + pytest.param( + VersionNotSupportedError("bad version", data={"version": "2.0"}), + A2AUnsupportedOperationError, + "version_not_supported", + {"version": "2.0"}, + id="unsupported-version", + ), + pytest.param( + A2AError("generic", data={"detail": "boom"}), + A2APeerProtocolError, + "peer_protocol_error", + {"detail": "boom"}, + id="generic-a2a-error", + ), + ], +) +def test_map_a2a_error_variants( + exc: A2AError, + expected_type: type[Exception], + error_code: str, + data: object | None, +) -> None: + mapped = map_a2a_error(exc) + + assert isinstance(mapped, expected_type) + assert mapped.error_code == error_code + assert mapped.data == data + + +@pytest.mark.parametrize( + ("exc", "expected_type", "http_status"), + [ + pytest.param( + FakeA2AClientHTTPError(401, "denied"), + A2AAuthenticationError, + 401, + id="401", + ), + pytest.param( + FakeA2AClientHTTPError(403, "forbidden"), + A2APermissionDeniedError, + 403, + id="403", + ), + pytest.param( + FakeA2AClientHTTPError(404, "missing"), + A2AUnsupportedOperationError, + 404, + id="404", + ), + pytest.param( + FakeA2AClientHTTPError(408, "slow"), + A2ATimeoutError, + 408, + id="408", + ), + pytest.param( + SDKClientError("HTTP Error 503: busy"), + A2AClientResetRequiredError, + 503, + id="503-from-message", + ), + pytest.param( + FakeA2AClientHTTPError(500, "boom"), + A2AAgentUnavailableError, + 500, + id="500", + ), + ], +) +def test_map_client_error_http_variants( + exc: SDKClientError, + expected_type: type[Exception], + http_status: int, +) -> None: + mapped = map_client_error("SendMessage", exc) + + assert isinstance(mapped, expected_type) + assert mapped.http_status == http_status + + +def test_map_client_error_timeout_variant() -> None: + mapped = map_client_error("SendMessage", A2AClientTimeoutError("timed out")) + + assert isinstance(mapped, A2ATimeoutError) + assert mapped.http_status is None + + +def test_map_client_error_without_status_returns_protocol_error() -> None: + mapped = map_client_error("SendMessage", SDKClientError("broken client")) + + assert isinstance(mapped, A2APeerProtocolError) + assert mapped.error_code == "invalid_client_error" + + def test_map_jsonrpc_error_variants() -> None: invalid_params_error = FakeA2AClientJSONRPCError( JSONRPCErrorResponse( - error=JSONRPCError(code=-32602, message="bad params"), + error=JSONRPCError(code=-32602, message="bad params", data={"field": "limit"}), id="req-1", ) ) @@ -51,9 +187,12 @@ def test_map_jsonrpc_error_variants() -> None: assert isinstance(mapped_invalid, A2APeerProtocolError) assert mapped_invalid.error_code == "invalid_params" + assert mapped_invalid.code == -32602 + assert mapped_invalid.data == {"field": "limit"} assert isinstance(mapped_internal, A2AClientResetRequiredError) assert isinstance(mapped_generic, A2APeerProtocolError) assert mapped_generic.error_code == "peer_protocol_error" + assert mapped_generic.data is None def test_map_http_error_variants() -> None: @@ -83,3 +222,33 @@ def test_map_agent_card_error_json_variant() -> None: assert isinstance(mapped, A2APeerProtocolError) assert mapped.error_code == "invalid_agent_card" + + +def test_map_agent_card_error_resolution_error_without_status_is_invalid_card() -> None: + mapped = map_agent_card_error(AgentCardResolutionError("invalid json")) + + assert isinstance(mapped, A2APeerProtocolError) + assert mapped.error_code == "invalid_agent_card" + + +def test_map_agent_card_error_resolution_error_with_status_uses_http_mapping() -> None: + mapped = map_agent_card_error(AgentCardResolutionError("forbidden", status_code=403)) + + assert isinstance(mapped, A2APermissionDeniedError) + assert mapped.http_status == 403 + + +@pytest.mark.parametrize( + ("exc", "expected_type"), + [ + pytest.param(httpx.ReadTimeout("timed out"), A2ATimeoutError, id="timeout"), + pytest.param(httpx.ConnectError("down"), A2AAgentUnavailableError, id="transport"), + ], +) +def test_map_agent_card_error_transport_variants( + exc: httpx.TimeoutException | httpx.TransportError, + expected_type: type[Exception], +) -> None: + mapped = map_agent_card_error(exc) + + assert isinstance(mapped, expected_type) diff --git a/tests/contracts/test_extension_contract_consistency.py b/tests/contracts/test_extension_contract_consistency.py index f1d6dd3..0a2835c 100644 --- a/tests/contracts/test_extension_contract_consistency.py +++ b/tests/contracts/test_extension_contract_consistency.py @@ -44,10 +44,10 @@ from opencode_a2a.protocol_versions import A2A_PROTOCOL_VERSION from opencode_a2a.server.agent_card import build_agent_card from opencode_a2a.server.application import create_app -from tests.support.helpers import ( +from tests.support.session_extensions import _extension_headers +from tests.support.session_query_client import ( DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) -from tests.support.session_extensions import _extension_headers from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_application_dispatch.py b/tests/jsonrpc/test_application_dispatch.py index 6d6b34f..3d2c99c 100644 --- a/tests/jsonrpc/test_application_dispatch.py +++ b/tests/jsonrpc/test_application_dispatch.py @@ -14,8 +14,8 @@ from opencode_a2a.contracts.extensions import SESSION_MANAGEMENT_EXTENSION_URI from opencode_a2a.jsonrpc.models import JSONRPCRequest from tests.support.async_iterators import iter_async -from tests.support.helpers import DummySessionQueryOpencodeUpstreamClient from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app +from tests.support.session_query_client import DummySessionQueryOpencodeUpstreamClient from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_dispatch_registry.py b/tests/jsonrpc/test_dispatch_registry.py index 700d9f1..b2991a6 100644 --- a/tests/jsonrpc/test_dispatch_registry.py +++ b/tests/jsonrpc/test_dispatch_registry.py @@ -6,9 +6,9 @@ from opencode_a2a.a2a_protocol import CORE_JSONRPC_METHODS from opencode_a2a.contracts.extensions import SESSION_MANAGEMENT_EXTENSION_URI from opencode_a2a.jsonrpc.application import OpencodeSessionManagementJSONRPCApplication -from tests.support.helpers import DummySessionQueryOpencodeUpstreamClient from tests.support.jsonrpc_error_assertions import assert_v1_error_reason, error_context_detail from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers, _jsonrpc_app +from tests.support.session_query_client import DummySessionQueryOpencodeUpstreamClient from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_opencode_session_extension_commands.py b/tests/jsonrpc/test_opencode_session_extension_commands.py index 6c6a646..dfb4782 100644 --- a/tests/jsonrpc/test_opencode_session_extension_commands.py +++ b/tests/jsonrpc/test_opencode_session_extension_commands.py @@ -1,9 +1,6 @@ import httpx import pytest -from tests.support.helpers import ( - DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, -) from tests.support.helpers import make_basic_auth_header from tests.support.jsonrpc_error_assertions import ( assert_v1_error_context, @@ -16,6 +13,9 @@ _jsonrpc_app, _session_meta, ) +from tests.support.session_query_client import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index afda4e9..914e797 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -3,14 +3,14 @@ from opencode_a2a.config import Settings from opencode_a2a.opencode_upstream_client import UpstreamConcurrencyLimitError -from tests.support.helpers import ( - DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, -) from tests.support.jsonrpc_error_assertions import ( assert_v1_error_reason, error_context_detail, ) from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers +from tests.support.session_query_client import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_opencode_session_extension_lifecycle.py b/tests/jsonrpc/test_opencode_session_extension_lifecycle.py index df04466..04a639e 100644 --- a/tests/jsonrpc/test_opencode_session_extension_lifecycle.py +++ b/tests/jsonrpc/test_opencode_session_extension_lifecycle.py @@ -3,9 +3,6 @@ import httpx import pytest -from tests.support.helpers import ( - DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, -) from tests.support.jsonrpc_error_assertions import assert_v1_error_reason from tests.support.session_extensions import ( _BASE_SETTINGS, @@ -13,6 +10,9 @@ _jsonrpc_app, _session_meta, ) +from tests.support.session_query_client import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py index e7dfe01..01da2c7 100644 --- a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py +++ b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py @@ -7,15 +7,15 @@ UpstreamConcurrencyLimitError, UpstreamContractError, ) -from tests.support.helpers import ( - DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, -) from tests.support.jsonrpc_error_assertions import ( assert_v1_error_metadata_contains, assert_v1_error_reason, error_context_detail, ) from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers, _jsonrpc_app +from tests.support.session_query_client import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_opencode_session_extension_queries.py b/tests/jsonrpc/test_opencode_session_extension_queries.py index 00b1a56..9246bb6 100644 --- a/tests/jsonrpc/test_opencode_session_extension_queries.py +++ b/tests/jsonrpc/test_opencode_session_extension_queries.py @@ -9,14 +9,14 @@ SESSION_QUERY_MAX_LIMIT, ) from opencode_a2a.opencode_upstream_client import UpstreamConcurrencyLimitError -from tests.support.helpers import ( - DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, -) from tests.support.jsonrpc_error_assertions import ( assert_v1_error_reason, error_context_detail, ) from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers, _session_meta +from tests.support.session_query_client import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) from tests.support.settings import make_settings diff --git a/tests/jsonrpc/test_opencode_workspace_control_extension.py b/tests/jsonrpc/test_opencode_workspace_control_extension.py index 9e3c3ea..6395986 100644 --- a/tests/jsonrpc/test_opencode_workspace_control_extension.py +++ b/tests/jsonrpc/test_opencode_workspace_control_extension.py @@ -1,15 +1,15 @@ import httpx import pytest -from tests.support.helpers import ( - DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, -) from tests.support.helpers import make_basic_auth_header from tests.support.jsonrpc_error_assertions import ( assert_v1_error_metadata_contains, assert_v1_error_reason, ) from tests.support.session_extensions import _BASE_SETTINGS, _extension_headers +from tests.support.session_query_client import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) from tests.support.settings import make_settings diff --git a/tests/server/test_cli.py b/tests/server/test_cli.py index 371baaf..3cbf07a 100644 --- a/tests/server/test_cli.py +++ b/tests/server/test_cli.py @@ -1,12 +1,113 @@ from __future__ import annotations +from collections.abc import AsyncIterator +from dataclasses import dataclass from unittest import mock import pytest +from a2a.types import TaskState +from pydantic import BaseModel, ValidationError, field_validator from opencode_a2a import __version__, cli +@dataclass +class _FakeTextPart: + text: str + + def HasField(self, name: str) -> bool: + return name == "text" + + +@dataclass +class _FakeMessage: + parts: list[_FakeTextPart] + + +@dataclass +class _FakeArtifact: + artifact_id: str + parts: list[_FakeTextPart] + + +@dataclass +class _FakeArtifactUpdate: + artifact: _FakeArtifact + append: bool + + +@dataclass +class _FakeStatus: + state: TaskState + message: str | None = None + + +@dataclass +class _FakeStatusUpdate: + status: _FakeStatus | None + + +class _FakeEvent: + def __init__( + self, + *, + message: _FakeMessage | None = None, + artifact_update: _FakeArtifactUpdate | None = None, + status_update: _FakeStatusUpdate | None = None, + ) -> None: + self.message = message + self.artifact_update = artifact_update + self.status_update = status_update + + def HasField(self, name: str) -> bool: + return getattr(self, name) is not None + + +def _message_event(*parts: str) -> _FakeEvent: + return _FakeEvent(message=_FakeMessage(parts=[_FakeTextPart(text=part) for part in parts])) + + +def _artifact_event(artifact_id: str, text: str, *, append: bool = False) -> _FakeEvent: + return _FakeEvent( + artifact_update=_FakeArtifactUpdate( + artifact=_FakeArtifact(artifact_id=artifact_id, parts=[_FakeTextPart(text=text)]), + append=append, + ) + ) + + +def _failed_status_event(message: str) -> _FakeEvent: + return _FakeEvent( + status_update=_FakeStatusUpdate( + status=_FakeStatus(state=TaskState.TASK_STATE_FAILED, message=message) + ) + ) + + +class _FakeA2AClient: + def __init__( + self, + _agent_url: str, + *, + settings: object, + events: list[_FakeEvent] | None = None, + error: Exception | None = None, + ) -> None: + self.settings = settings + self._events = events or [] + self._error = error + self.closed = False + + async def send_message(self, _text: str) -> AsyncIterator[_FakeEvent]: + for event in self._events: + yield event + if self._error is not None: + raise self._error + + async def close(self) -> None: + self.closed = True + + def test_cli_help_does_not_require_runtime_settings(capsys: pytest.CaptureFixture[str]) -> None: with mock.patch("opencode_a2a.cli.serve_main") as serve_mock: with pytest.raises(SystemExit) as excinfo: @@ -132,3 +233,80 @@ def test_cli_call_rejects_basic_flag() -> None: parser.parse_args(["call", "http://agent.example.com", "hello", "--basic", "user:pass"]) assert excinfo.value.code == 2 + + +class _DemoSettingsModel(BaseModel): + token: str + + @field_validator("token") + @classmethod + def _validate_token(cls, value: str) -> str: + raise ValueError("missing token") + + +def test_validate_serve_configuration_formats_validation_errors() -> None: + with pytest.raises(ValidationError) as excinfo: + _DemoSettingsModel(token="placeholder") + + with mock.patch("opencode_a2a.cli.Settings", side_effect=excinfo.value): + assert cli.validate_serve_configuration() == ["token: missing token"] + + +@pytest.mark.asyncio +async def test_run_call_renders_incremental_artifacts_without_duplication( + capsys: pytest.CaptureFixture[str], +) -> None: + settings = object() + fake_client = _FakeA2AClient( + "http://agent.example.com", + settings=settings, + events=[ + _message_event("hello "), + _artifact_event("artifact-1", "abc"), + _artifact_event("artifact-1", "abcdef"), + _artifact_event("artifact-1", "abcdef"), + _artifact_event("artifact-1", "!", append=True), + ], + ) + + with mock.patch("opencode_a2a.cli.load_settings", return_value=settings): + with mock.patch("opencode_a2a.cli.A2AClient", return_value=fake_client): + assert await cli.run_call("http://agent.example.com", "hello") == 0 + + assert capsys.readouterr().out == "hello abcdef!\n" + assert fake_client.closed is True + + +@pytest.mark.asyncio +async def test_run_call_prints_failed_status_message( + capsys: pytest.CaptureFixture[str], +) -> None: + fake_client = _FakeA2AClient( + "http://agent.example.com", + settings=object(), + events=[_failed_status_event("task failed")], + ) + + with mock.patch("opencode_a2a.cli.load_settings", return_value=object()): + with mock.patch("opencode_a2a.cli.A2AClient", return_value=fake_client): + assert await cli.run_call("http://agent.example.com", "hello") == 0 + + assert "[Failed] task failed" in capsys.readouterr().out + assert fake_client.closed is True + + +@pytest.mark.asyncio +async def test_run_call_reports_errors_to_stderr(capsys: pytest.CaptureFixture[str]) -> None: + fake_client = _FakeA2AClient( + "http://agent.example.com", + settings=object(), + error=RuntimeError("boom"), + ) + + with mock.patch("opencode_a2a.cli.load_settings", return_value=object()): + with mock.patch("opencode_a2a.cli.A2AClient", return_value=fake_client): + assert await cli.run_call("http://agent.example.com", "hello") == 1 + + captured = capsys.readouterr() + assert "[Error] boom" in captured.err + assert fake_client.closed is True diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index ef972b8..aafdbd0 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -43,9 +43,9 @@ from opencode_a2a.trace_context import parse_traceparent from tests.support.helpers import ( DummyChatOpencodeUpstreamClient, - DummySessionQueryOpencodeUpstreamClient, make_basic_auth_header, ) +from tests.support.session_query_client import DummySessionQueryOpencodeUpstreamClient from tests.support.settings import make_settings diff --git a/tests/support/helpers.py b/tests/support/helpers.py index e33f832..c79f883 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -14,7 +14,7 @@ SESSION_BINDING_EXTENSION_URI, STREAMING_EXTENSION_URI, ) -from opencode_a2a.opencode_upstream_client import OpencodeMessage, OpencodeMessagePage +from opencode_a2a.opencode_upstream_client import OpencodeMessage from opencode_a2a.server.context_helpers import normalize_server_call_context from tests.support import settings as test_settings @@ -258,632 +258,3 @@ async def resolve_interrupt_session(self, request_id: str) -> str | None: async def discard_interrupt_request(self, request_id: str) -> None: del request_id - - -class DummySessionQueryOpencodeUpstreamClient: - def __init__( - self, - _settings: Settings, - *, - interrupt_request_repository=None, # noqa: ANN001 - ) -> None: - del interrupt_request_repository - self.settings = _settings - self.directory = _settings.opencode_workspace_root - self._sessions_payload = [{"id": "s-1", "title": "Session s-1"}] - self._session_status_payload = { - "s-1": {"type": "idle"}, - "s-2": {"type": "retry", "attempt": 2, "message": "retrying", "next": 30}, - } - self._session_payload = { - "id": "s-1", - "title": "Session s-1", - "directory": "/workspace", - "projectID": "proj-1", - } - self._child_sessions_payload = [{"id": "s-2", "title": "Child session"}] - self._todo_payload = [ - { - "id": "todo-1", - "content": "Review the diff", - "status": "pending", - "priority": "high", - } - ] - self._diff_payload = [ - { - "file": "src/app.py", - "before": "old", - "after": "new", - "additions": 3, - "deletions": 1, - } - ] - self._messages_payload = [ - { - "info": {"id": "m-1", "role": "assistant"}, - "parts": [{"type": "text", "text": "SECRET_HISTORY"}], - } - ] - self._message_payload = { - "info": {"id": "m-1", "role": "assistant"}, - "parts": [{"type": "text", "text": "One message payload"}], - } - self._reverted_session_payload = { - "id": "s-1", - "title": "Reverted session", - "directory": "/workspace", - "projectID": "proj-1", - "revert": { - "messageID": "msg-1", - "partID": "part-1", - "snapshot": "snap-1", - "diff": "diff-1", - }, - } - self._unreverted_session_payload = { - "id": "s-1", - "title": "Restored session", - "directory": "/workspace", - "projectID": "proj-1", - } - self._messages_next_cursor: str | None = None - self.last_sessions_params = None - self.last_sessions_directory: str | None = None - self.last_sessions_workspace_id: str | None = None - self.last_messages_params = None - self.last_messages_workspace_id: str | None = None - self.lifecycle_calls: list[dict[str, Any]] = [] - self.prompt_async_calls: list[dict[str, Any]] = [] - self.command_calls: list[dict[str, Any]] = [] - self.shell_calls: list[dict[str, Any]] = [] - self.workspace_control_calls: list[dict[str, Any]] = [] - self.provider_catalog_payload: dict[str, Any] = { - "all": [ - { - "id": "openai", - "name": "OpenAI", - "source": "api", - "models": { - "gpt-5": { - "name": "GPT-5", - "status": "active", - "limit": {"context": 200000, "output": 8192}, - "capabilities": { - "reasoning": True, - "toolcall": True, - "attachment": False, - }, - } - }, - }, - { - "id": "google", - "name": "Google", - "source": "config", - "models": { - "gemini-2.5-flash": { - "name": "Gemini 2.5 Flash", - "status": "beta", - "limit": {"context": 1000000, "output": 8192}, - "capabilities": { - "reasoning": True, - "toolcall": True, - "attachment": True, - }, - } - }, - }, - ], - "default": { - "openai": "gpt-5", - "google": "gemini-2.5-flash", - }, - "connected": ["openai"], - } - self._interrupt_requests: dict[str, dict[str, str | None]] = {} - self._interrupt_request_details: dict[str, dict[str, Any] | None] = {} - - async def close(self) -> None: - return None - - async def list_sessions( - self, - *, - params=None, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.last_sessions_directory = directory - self.last_sessions_workspace_id = workspace_id - self.last_sessions_params = params - return self._sessions_payload - - async def list_messages(self, session_id: str, *, params=None, workspace_id: str | None = None): - assert session_id - self.last_messages_params = params - self.last_messages_workspace_id = workspace_id - return OpencodeMessagePage( - payload=self._messages_payload, - next_cursor=self._messages_next_cursor, - ) - - async def session_status( - self, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "session_status", - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._session_status_payload - - async def get_session( - self, - session_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "get_session", - "session_id": session_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._session_payload - - async def list_child_sessions( - self, - session_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "list_child_sessions", - "session_id": session_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._child_sessions_payload - - async def get_session_todo( - self, - session_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "get_session_todo", - "session_id": session_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._todo_payload - - async def get_session_diff( - self, - session_id: str, - *, - params=None, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "get_session_diff", - "session_id": session_id, - "params": params, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._diff_payload - - async def get_message( - self, - session_id: str, - message_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "get_message", - "session_id": session_id, - "message_id": message_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._message_payload - - async def session_prompt_async( - self, - session_id: str, - request: dict[str, Any], - *, - directory: str | None = None, - workspace_id: str | None = None, - ) -> None: - self.prompt_async_calls.append( - { - "session_id": session_id, - "request": request, - "directory": directory, - "workspace_id": workspace_id, - } - ) - - async def session_command( - self, - session_id: str, - request: dict[str, Any], - *, - directory: str | None = None, - workspace_id: str | None = None, - ) -> dict[str, Any]: - self.command_calls.append( - { - "session_id": session_id, - "request": request, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return { - "info": {"id": "msg-command-1", "role": "assistant"}, - "parts": [{"type": "text", "text": "Command completed."}], - } - - async def session_shell( - self, - session_id: str, - request: dict[str, Any], - *, - directory: str | None = None, - workspace_id: str | None = None, - ) -> dict[str, Any]: - self.shell_calls.append( - { - "session_id": session_id, - "request": request, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return { - "id": "msg-shell-1", - "role": "assistant", - "parts": [{"type": "text", "text": "Shell command executed."}], - } - - async def fork_session( - self, - session_id: str, - request: dict[str, Any] | None = None, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "fork_session", - "session_id": session_id, - "request": request, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return { - "id": "s-2", - "title": "Forked session", - "parentID": session_id, - "directory": "/workspace", - "projectID": "proj-1", - } - - async def share_session( - self, - session_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "share_session", - "session_id": session_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return { - "id": session_id, - "title": "Shared session", - "directory": "/workspace", - "projectID": "proj-1", - "share": {"url": "https://example.com/shared/s-1"}, - } - - async def unshare_session( - self, - session_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "unshare_session", - "session_id": session_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return { - "id": session_id, - "title": "Unshared session", - "directory": "/workspace", - "projectID": "proj-1", - } - - async def summarize_session( - self, - session_id: str, - request: dict[str, Any] | None = None, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "summarize_session", - "session_id": session_id, - "request": request, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return True - - async def revert_session( - self, - session_id: str, - request: dict[str, Any], - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "revert_session", - "session_id": session_id, - "request": request, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._reverted_session_payload - - async def unrevert_session( - self, - session_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.lifecycle_calls.append( - { - "method": "unrevert_session", - "session_id": session_id, - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self._unreverted_session_payload - - async def list_provider_catalog( - self, - *, - directory: str | None = None, - workspace_id: str | None = None, - ): - self.workspace_control_calls.append( - { - "method": "provider_catalog", - "directory": directory, - "workspace_id": workspace_id, - } - ) - return self.provider_catalog_payload - - async def list_projects(self): - self.workspace_control_calls.append({"method": "list_projects"}) - return [{"id": "proj-1", "name": "Alpha", "directory": "/workspace"}] - - async def get_current_project(self): - self.workspace_control_calls.append({"method": "get_current_project"}) - return {"id": "proj-1", "name": "Alpha", "directory": "/workspace"} - - async def list_workspaces(self): - self.workspace_control_calls.append({"method": "list_workspaces"}) - return [{"id": "wrk-1", "type": "git", "branch": "main", "directory": None}] - - async def create_workspace(self, request: dict[str, Any]): - self.workspace_control_calls.append({"method": "create_workspace", "request": request}) - return {"id": "wrk-2", **request} - - async def remove_workspace(self, workspace_id: str): - self.workspace_control_calls.append( - {"method": "remove_workspace", "workspace_id": workspace_id} - ) - return {"id": workspace_id, "type": "git", "branch": "main", "directory": None} - - async def list_worktrees(self): - self.workspace_control_calls.append({"method": "list_worktrees"}) - return ["/tmp/worktrees/alpha"] - - async def create_worktree(self, request: dict[str, Any]): - self.workspace_control_calls.append({"method": "create_worktree", "request": request}) - return { - "name": request.get("name") or "feature-branch", - "branch": "opencode/feature-branch", - "directory": "/tmp/worktrees/feature-branch", - } - - async def remove_worktree(self, request: dict[str, Any]) -> bool: - self.workspace_control_calls.append({"method": "remove_worktree", "request": request}) - return True - - async def reset_worktree(self, request: dict[str, Any]) -> bool: - self.workspace_control_calls.append({"method": "reset_worktree", "request": request}) - return True - - async def remember_interrupt_request( - self, - *, - request_id: str, - session_id: str, - interrupt_type: str, - identity: str | None = None, - credential_id: str | None = None, - task_id: str | None = None, - context_id: str | None = None, - details: dict[str, Any] | None = None, - ttl_seconds: float | None = None, - ) -> None: - del ttl_seconds - self._interrupt_requests[request_id] = { - "session_id": session_id, - "interrupt_type": interrupt_type, - "identity": identity, - "credential_id": credential_id, - "task_id": task_id, - "context_id": context_id, - } - self._interrupt_request_details[request_id] = ( - dict(details) if isinstance(details, dict) else None - ) - - async def resolve_interrupt_request(self, request_id: str): - payload = self._interrupt_requests.get(request_id) - if payload is None: - return "missing", None - - class _Binding: - def __init__(self, data: dict[str, str | None]) -> None: - self.request_id = request_id - self.session_id = data.get("session_id") - self.interrupt_type = data.get("interrupt_type") - self.identity = data.get("identity") - self.credential_id = data.get("credential_id") - self.task_id = data.get("task_id") - self.context_id = data.get("context_id") - self.details = self_details - - self_details = self._interrupt_request_details.get(request_id) - - return "active", _Binding(payload) - - async def resolve_interrupt_session(self, request_id: str) -> str | None: - payload = self._interrupt_requests.get(request_id) - if payload is None: - return None - return payload.get("session_id") - - async def discard_interrupt_request(self, request_id: str) -> None: - self._interrupt_requests.pop(request_id, None) - self._interrupt_request_details.pop(request_id, None) - - async def list_interrupt_requests( - self, - *, - identity: str, - interrupt_type: str | None = None, - ): - class _Binding: - def __init__( - self, - *, - request_id: str, - data: dict[str, str | None], - details: dict[str, Any] | None, - ) -> None: - self.request_id = request_id - self.session_id = data.get("session_id") - self.interrupt_type = data.get("interrupt_type") - self.identity = data.get("identity") - self.credential_id = data.get("credential_id") - self.task_id = data.get("task_id") - self.context_id = data.get("context_id") - self.details = details - self.expires_at = 0.0 - - items = [] - for request_id, payload in self._interrupt_requests.items(): - if payload.get("identity") != identity: - continue - if interrupt_type is not None and payload.get("interrupt_type") != interrupt_type: - continue - items.append( - _Binding( - request_id=request_id, - data=payload, - details=self._interrupt_request_details.get(request_id), - ) - ) - return items - - async def list_permission_requests(self, *, identity: str): - return await self.list_interrupt_requests(identity=identity, interrupt_type="permission") - - async def list_question_requests(self, *, identity: str): - return await self.list_interrupt_requests(identity=identity, interrupt_type="question") - - async def permission_reply( - self, - request_id: str, - *, - reply: str, - message: str | None = None, - directory: str | None = None, - workspace_id: str | None = None, - ) -> bool: - del request_id, reply, message, directory, workspace_id - return True - - async def question_reply( - self, - request_id: str, - *, - answers: list[list[str]], - directory: str | None = None, - workspace_id: str | None = None, - ) -> bool: - del request_id, answers, directory, workspace_id - return True - - async def question_reject( - self, - request_id: str, - *, - directory: str | None = None, - workspace_id: str | None = None, - ) -> bool: - del request_id, directory, workspace_id - return True diff --git a/tests/support/interrupt_clients.py b/tests/support/interrupt_clients.py new file mode 100644 index 0000000..e5a060f --- /dev/null +++ b/tests/support/interrupt_clients.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import Any + + +class InterruptRequestClientMixin: + _interrupt_requests: dict[str, dict[str, str | None]] + _interrupt_request_details: dict[str, dict[str, Any] | None] + + async def remember_interrupt_request( + self, + *, + request_id: str, + session_id: str, + interrupt_type: str, + identity: str | None = None, + credential_id: str | None = None, + task_id: str | None = None, + context_id: str | None = None, + details: dict[str, Any] | None = None, + ttl_seconds: float | None = None, + ) -> None: + del ttl_seconds + self._interrupt_requests[request_id] = { + "session_id": session_id, + "interrupt_type": interrupt_type, + "identity": identity, + "credential_id": credential_id, + "task_id": task_id, + "context_id": context_id, + } + self._interrupt_request_details[request_id] = ( + dict(details) if isinstance(details, dict) else None + ) + + async def resolve_interrupt_request(self, request_id: str): + payload = self._interrupt_requests.get(request_id) + if payload is None: + return "missing", None + details = self._interrupt_request_details.get(request_id) + + class _Binding: + def __init__(self, data: dict[str, str | None]) -> None: + self.request_id = request_id + self.session_id = data.get("session_id") + self.interrupt_type = data.get("interrupt_type") + self.identity = data.get("identity") + self.credential_id = data.get("credential_id") + self.task_id = data.get("task_id") + self.context_id = data.get("context_id") + self.details = details + + return "active", _Binding(payload) + + async def resolve_interrupt_session(self, request_id: str) -> str | None: + payload = self._interrupt_requests.get(request_id) + if payload is None: + return None + return payload.get("session_id") + + async def discard_interrupt_request(self, request_id: str) -> None: + self._interrupt_requests.pop(request_id, None) + self._interrupt_request_details.pop(request_id, None) + + async def list_interrupt_requests( + self, + *, + identity: str, + interrupt_type: str | None = None, + ): + class _Binding: + def __init__( + self, + *, + request_id: str, + data: dict[str, str | None], + details: dict[str, Any] | None, + ) -> None: + self.request_id = request_id + self.session_id = data.get("session_id") + self.interrupt_type = data.get("interrupt_type") + self.identity = data.get("identity") + self.credential_id = data.get("credential_id") + self.task_id = data.get("task_id") + self.context_id = data.get("context_id") + self.details = details + self.expires_at = 0.0 + + items = [] + for request_id, payload in self._interrupt_requests.items(): + if payload.get("identity") != identity: + continue + if interrupt_type is not None and payload.get("interrupt_type") != interrupt_type: + continue + items.append( + _Binding( + request_id=request_id, + data=payload, + details=self._interrupt_request_details.get(request_id), + ) + ) + return items + + async def list_permission_requests(self, *, identity: str): + return await self.list_interrupt_requests(identity=identity, interrupt_type="permission") + + async def list_question_requests(self, *, identity: str): + return await self.list_interrupt_requests(identity=identity, interrupt_type="question") + + async def permission_reply( + self, + request_id: str, + *, + reply: str, + message: str | None = None, + directory: str | None = None, + workspace_id: str | None = None, + ) -> bool: + del request_id, reply, message, directory, workspace_id + return True + + async def question_reply( + self, + request_id: str, + *, + answers: list[list[str]], + directory: str | None = None, + workspace_id: str | None = None, + ) -> bool: + del request_id, answers, directory, workspace_id + return True + + async def question_reject( + self, + request_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ) -> bool: + del request_id, directory, workspace_id + return True diff --git a/tests/support/session_query_client.py b/tests/support/session_query_client.py new file mode 100644 index 0000000..f4447af --- /dev/null +++ b/tests/support/session_query_client.py @@ -0,0 +1,450 @@ +from __future__ import annotations + +from typing import Any + +from opencode_a2a.config import Settings +from opencode_a2a.opencode_upstream_client import OpencodeMessagePage + +from tests.support.interrupt_clients import InterruptRequestClientMixin +from tests.support.workspace_control_client import WorkspaceControlClientMixin + + +class DummySessionQueryOpencodeUpstreamClient( + WorkspaceControlClientMixin, + InterruptRequestClientMixin, +): + def __init__( + self, + _settings: Settings, + *, + interrupt_request_repository=None, # noqa: ANN001 + ) -> None: + del interrupt_request_repository + self.settings = _settings + self.directory = _settings.opencode_workspace_root + self._sessions_payload = [{"id": "s-1", "title": "Session s-1"}] + self._session_status_payload = { + "s-1": {"type": "idle"}, + "s-2": {"type": "retry", "attempt": 2, "message": "retrying", "next": 30}, + } + self._session_payload = { + "id": "s-1", + "title": "Session s-1", + "directory": "/workspace", + "projectID": "proj-1", + } + self._child_sessions_payload = [{"id": "s-2", "title": "Child session"}] + self._todo_payload = [ + { + "id": "todo-1", + "content": "Review the diff", + "status": "pending", + "priority": "high", + } + ] + self._diff_payload = [ + { + "file": "src/app.py", + "before": "old", + "after": "new", + "additions": 3, + "deletions": 1, + } + ] + self._messages_payload = [ + { + "info": {"id": "m-1", "role": "assistant"}, + "parts": [{"type": "text", "text": "SECRET_HISTORY"}], + } + ] + self._message_payload = { + "info": {"id": "m-1", "role": "assistant"}, + "parts": [{"type": "text", "text": "One message payload"}], + } + self._reverted_session_payload = { + "id": "s-1", + "title": "Reverted session", + "directory": "/workspace", + "projectID": "proj-1", + "revert": { + "messageID": "msg-1", + "partID": "part-1", + "snapshot": "snap-1", + "diff": "diff-1", + }, + } + self._unreverted_session_payload = { + "id": "s-1", + "title": "Restored session", + "directory": "/workspace", + "projectID": "proj-1", + } + self._messages_next_cursor: str | None = None + self.last_sessions_params = None + self.last_sessions_directory: str | None = None + self.last_sessions_workspace_id: str | None = None + self.last_messages_params = None + self.last_messages_workspace_id: str | None = None + self.lifecycle_calls: list[dict[str, Any]] = [] + self.prompt_async_calls: list[dict[str, Any]] = [] + self.command_calls: list[dict[str, Any]] = [] + self.shell_calls: list[dict[str, Any]] = [] + self.workspace_control_calls: list[dict[str, Any]] = [] + self.provider_catalog_payload: dict[str, Any] = { + "all": [ + { + "id": "openai", + "name": "OpenAI", + "source": "api", + "models": { + "gpt-5": { + "name": "GPT-5", + "status": "active", + "limit": {"context": 200000, "output": 8192}, + "capabilities": { + "reasoning": True, + "toolcall": True, + "attachment": False, + }, + } + }, + }, + { + "id": "google", + "name": "Google", + "source": "config", + "models": { + "gemini-2.5-flash": { + "name": "Gemini 2.5 Flash", + "status": "beta", + "limit": {"context": 1000000, "output": 8192}, + "capabilities": { + "reasoning": True, + "toolcall": True, + "attachment": True, + }, + } + }, + }, + ], + "default": { + "openai": "gpt-5", + "google": "gemini-2.5-flash", + }, + "connected": ["openai"], + } + self._interrupt_requests: dict[str, dict[str, str | None]] = {} + self._interrupt_request_details: dict[str, dict[str, Any] | None] = {} + + async def close(self) -> None: + return None + + async def list_sessions( + self, + *, + params=None, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.last_sessions_directory = directory + self.last_sessions_workspace_id = workspace_id + self.last_sessions_params = params + return self._sessions_payload + + async def list_messages(self, session_id: str, *, params=None, workspace_id: str | None = None): + assert session_id + self.last_messages_params = params + self.last_messages_workspace_id = workspace_id + return OpencodeMessagePage( + payload=self._messages_payload, + next_cursor=self._messages_next_cursor, + ) + + async def session_status( + self, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "session_status", + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._session_status_payload + + async def get_session( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "get_session", + "session_id": session_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._session_payload + + async def list_child_sessions( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "list_child_sessions", + "session_id": session_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._child_sessions_payload + + async def get_session_todo( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "get_session_todo", + "session_id": session_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._todo_payload + + async def get_session_diff( + self, + session_id: str, + *, + params=None, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "get_session_diff", + "session_id": session_id, + "params": params, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._diff_payload + + async def get_message( + self, + session_id: str, + message_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "get_message", + "session_id": session_id, + "message_id": message_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._message_payload + + async def session_prompt_async( + self, + session_id: str, + request: dict[str, Any], + *, + directory: str | None = None, + workspace_id: str | None = None, + ) -> None: + self.prompt_async_calls.append( + { + "session_id": session_id, + "request": request, + "directory": directory, + "workspace_id": workspace_id, + } + ) + + async def session_command( + self, + session_id: str, + request: dict[str, Any], + *, + directory: str | None = None, + workspace_id: str | None = None, + ) -> dict[str, Any]: + self.command_calls.append( + { + "session_id": session_id, + "request": request, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return { + "info": {"id": "msg-command-1", "role": "assistant"}, + "parts": [{"type": "text", "text": "Command completed."}], + } + + async def session_shell( + self, + session_id: str, + request: dict[str, Any], + *, + directory: str | None = None, + workspace_id: str | None = None, + ) -> dict[str, Any]: + self.shell_calls.append( + { + "session_id": session_id, + "request": request, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return { + "id": "msg-shell-1", + "role": "assistant", + "parts": [{"type": "text", "text": "Shell command executed."}], + } + + async def fork_session( + self, + session_id: str, + request: dict[str, Any] | None = None, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "fork_session", + "session_id": session_id, + "request": request, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return { + "id": "s-2", + "title": "Forked session", + "parentID": session_id, + "directory": "/workspace", + "projectID": "proj-1", + } + + async def share_session( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "share_session", + "session_id": session_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return { + "id": session_id, + "title": "Shared session", + "directory": "/workspace", + "projectID": "proj-1", + "share": {"url": "https://example.com/shared/s-1"}, + } + + async def unshare_session( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "unshare_session", + "session_id": session_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return { + "id": session_id, + "title": "Unshared session", + "directory": "/workspace", + "projectID": "proj-1", + } + + async def summarize_session( + self, + session_id: str, + request: dict[str, Any] | None = None, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "summarize_session", + "session_id": session_id, + "request": request, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return True + + async def revert_session( + self, + session_id: str, + request: dict[str, Any], + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "revert_session", + "session_id": session_id, + "request": request, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._reverted_session_payload + + async def unrevert_session( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.lifecycle_calls.append( + { + "method": "unrevert_session", + "session_id": session_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self._unreverted_session_payload diff --git a/tests/support/workspace_control_client.py b/tests/support/workspace_control_client.py new file mode 100644 index 0000000..d54b8c2 --- /dev/null +++ b/tests/support/workspace_control_client.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Any + + +class WorkspaceControlClientMixin: + workspace_control_calls: list[dict[str, Any]] + provider_catalog_payload: dict[str, Any] + + async def list_provider_catalog( + self, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.workspace_control_calls.append( + { + "method": "provider_catalog", + "directory": directory, + "workspace_id": workspace_id, + } + ) + return self.provider_catalog_payload + + async def list_projects(self): + self.workspace_control_calls.append({"method": "list_projects"}) + return [{"id": "proj-1", "name": "Alpha", "directory": "/workspace"}] + + async def get_current_project(self): + self.workspace_control_calls.append({"method": "get_current_project"}) + return {"id": "proj-1", "name": "Alpha", "directory": "/workspace"} + + async def list_workspaces(self): + self.workspace_control_calls.append({"method": "list_workspaces"}) + return [{"id": "wrk-1", "type": "git", "branch": "main", "directory": None}] + + async def create_workspace(self, request: dict[str, Any]): + self.workspace_control_calls.append({"method": "create_workspace", "request": request}) + return {"id": "wrk-2", **request} + + async def remove_workspace(self, workspace_id: str): + self.workspace_control_calls.append( + {"method": "remove_workspace", "workspace_id": workspace_id} + ) + return {"id": workspace_id, "type": "git", "branch": "main", "directory": None} + + async def list_worktrees(self): + self.workspace_control_calls.append({"method": "list_worktrees"}) + return ["/tmp/worktrees/alpha"] + + async def create_worktree(self, request: dict[str, Any]): + self.workspace_control_calls.append({"method": "create_worktree", "request": request}) + return { + "name": request.get("name") or "feature-branch", + "branch": "opencode/feature-branch", + "directory": "/tmp/worktrees/feature-branch", + } + + async def remove_worktree(self, request: dict[str, Any]) -> bool: + self.workspace_control_calls.append({"method": "remove_worktree", "request": request}) + return True + + async def reset_worktree(self, request: dict[str, Any]) -> bool: + self.workspace_control_calls.append({"method": "reset_worktree", "request": request}) + return True diff --git a/tests/test_parsing.py b/tests/test_parsing.py new file mode 100644 index 0000000..2bfe2b0 --- /dev/null +++ b/tests/test_parsing.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from opencode_a2a.parsing import ( + parse_bool_field, + parse_int_field, + parse_string_field, + parse_timestamp_field, +) + + +def _error_factory(field: str, message: str) -> ValueError: + return ValueError(f"{field}: {message}") + + +@pytest.mark.parametrize( + ("value", "minimum", "expected"), + [ + pytest.param(None, None, None, id="none"), + pytest.param(7, None, 7, id="int"), + pytest.param("12", None, 12, id="string-int"), + pytest.param("0", 0, 0, id="string-zero"), + ], +) +def test_parse_int_field_accepts_supported_values( + value: object, + minimum: int | None, + expected: int | None, +) -> None: + assert ( + parse_int_field( + value, + field="limit", + error_factory=_error_factory, + minimum=minimum, + ) + == expected + ) + + +@pytest.mark.parametrize( + ("value", "minimum", "message"), + [ + pytest.param(True, None, "limit must be an integer", id="bool"), + pytest.param("abc", None, "limit must be an integer", id="non-numeric-string"), + pytest.param(1.5, None, "limit must be an integer", id="float"), + pytest.param(-1, 0, "limit must be >= 0", id="below-minimum"), + ], +) +def test_parse_int_field_rejects_invalid_values( + value: object, + minimum: int | None, + message: str, +) -> None: + with pytest.raises(ValueError, match=message): + parse_int_field( + value, + field="limit", + error_factory=_error_factory, + minimum=minimum, + ) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + pytest.param(None, None, id="none"), + pytest.param(" hello ", "hello", id="trim"), + pytest.param(" ", None, id="blank"), + ], +) +def test_parse_string_field_normalizes_whitespace(value: object, expected: str | None) -> None: + assert parse_string_field(value, field="cursor", error_factory=_error_factory) == expected + + +def test_parse_string_field_rejects_non_strings() -> None: + with pytest.raises(ValueError, match="cursor must be a string"): + parse_string_field(42, field="cursor", error_factory=_error_factory) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + pytest.param(None, None, id="none"), + pytest.param(True, True, id="bool"), + pytest.param(" YES ", True, id="string-true"), + pytest.param("off", False, id="string-false"), + ], +) +def test_parse_bool_field_accepts_supported_values(value: object, expected: bool | None) -> None: + assert parse_bool_field(value, field="roots", error_factory=_error_factory) is expected + + +@pytest.mark.parametrize("value", [1, "maybe", object()]) +def test_parse_bool_field_rejects_invalid_values(value: object) -> None: + with pytest.raises(ValueError, match="roots must be a boolean"): + parse_bool_field(value, field="roots", error_factory=_error_factory) + + +def test_parse_timestamp_field_supports_z_suffix() -> None: + parsed = parse_timestamp_field( + "2025-01-02T03:04:05Z", + field="timestamp", + error_factory=_error_factory, + ) + + assert parsed == datetime(2025, 1, 2, 3, 4, 5, tzinfo=UTC) + + +def test_parse_timestamp_field_promotes_naive_values_to_utc() -> None: + parsed = parse_timestamp_field( + "2025-01-02T03:04:05", + field="timestamp", + error_factory=_error_factory, + ) + + assert parsed == datetime(2025, 1, 2, 3, 4, 5, tzinfo=UTC) + + +@pytest.mark.parametrize("value", [123, "not-a-timestamp"]) +def test_parse_timestamp_field_rejects_invalid_values(value: object) -> None: + with pytest.raises(ValueError, match="timestamp must be a valid ISO 8601 timestamp"): + parse_timestamp_field(value, field="timestamp", error_factory=_error_factory) From bdcfd5c8767cd97f6be896ba4d81d598186a3d97 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Wed, 6 May 2026 20:40:11 +0800 Subject: [PATCH 2/2] refactor: inline jsonrpc params alias parsing --- src/opencode_a2a/jsonrpc/params.py | 49 +++------------------------ tests/support/session_query_client.py | 1 - 2 files changed, 5 insertions(+), 45 deletions(-) diff --git a/src/opencode_a2a/jsonrpc/params.py b/src/opencode_a2a/jsonrpc/params.py index 8dfd7b7..160ea83 100644 --- a/src/opencode_a2a/jsonrpc/params.py +++ b/src/opencode_a2a/jsonrpc/params.py @@ -94,56 +94,23 @@ def _normalize_session_query_limit( return {"limit": normalized_limit} -def _normalize_alias_field( - *, - params: dict[str, Any], - field: str, - parser, -) -> Any: - return parser(params.get(field), field=field) - - def parse_list_sessions_params(params: dict[str, Any]) -> dict[str, Any]: _reject_nested_query_params(params) _validate_pagination_fields(params) normalized_query = _normalize_session_query_limit(limit=params.get("limit")) - directory = _normalize_alias_field( - params=params, - field="directory", - parser=_parse_string_field, - ) - roots = _normalize_alias_field( - params=params, - field="roots", - parser=_parse_bool_field, - ) - start = _normalize_alias_field( - params=params, - field="start", - parser=_parse_non_negative_int, - ) - search = _normalize_alias_field( - params=params, - field="search", - parser=_parse_string_field, - ) + directory = _parse_string_field(params.get("directory"), field="directory") + roots = _parse_bool_field(params.get("roots"), field="roots") + start = _parse_non_negative_int(params.get("start"), field="start") + search = _parse_string_field(params.get("search"), field="search") if directory is not None: normalized_query["directory"] = directory - else: - normalized_query.pop("directory", None) if roots is not None: normalized_query["roots"] = roots - else: - normalized_query.pop("roots", None) if start is not None: normalized_query["start"] = start - else: - normalized_query.pop("start", None) if search is not None: normalized_query["search"] = search - else: - normalized_query.pop("search", None) return normalized_query @@ -158,13 +125,7 @@ def parse_get_session_messages_params(params: dict[str, Any]) -> tuple[str, dict _reject_nested_query_params(params) _validate_pagination_fields(params) normalized_query = _normalize_session_query_limit(limit=params.get("limit")) - before = _normalize_alias_field( - params=params, - field="before", - parser=_parse_string_field, - ) + before = _parse_string_field(params.get("before"), field="before") if before is not None: normalized_query["before"] = before - else: - normalized_query.pop("before", None) return raw_session_id.strip(), normalized_query diff --git a/tests/support/session_query_client.py b/tests/support/session_query_client.py index f4447af..28197d2 100644 --- a/tests/support/session_query_client.py +++ b/tests/support/session_query_client.py @@ -4,7 +4,6 @@ from opencode_a2a.config import Settings from opencode_a2a.opencode_upstream_client import OpencodeMessagePage - from tests.support.interrupt_clients import InterruptRequestClientMixin from tests.support.workspace_control_client import WorkspaceControlClientMixin