From fcc661adaf11d624dd7b0190cdb018edd113fd22 Mon Sep 17 00:00:00 2001 From: Diogo Andre Santos Date: Sat, 11 Apr 2026 12:30:20 +0100 Subject: [PATCH 1/3] feat: add built-in MCPDriver with stdio and Streamable HTTP transports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #41, #53, #54. - New MCPDriver class (src/agent_kernel/drivers/mcp.py, ~180 LOC) plus mcp_support.py helpers (~140 LOC) kept within the ≤300-line module rule. - from_stdio() and from_http() class methods; initialize() is called inside each transport's session factory rather than in _run_with_retry() itself, so it runs once per session, and a new session is opened for every attempt. - discover() converts the tools/list response into Capability objects with configurable namespace and safety_class_map. - execute() strips the 'operation' key, merges ctx.constraints as defaults, calls tools/call, and normalises MCP content blocks / structuredContent into plain Python data for the firewall. - isError responses raise DriverError with the server-provided detail. - import_optional() raises a helpful ImportError if mcp is not installed. - MCPDriver exported from agent_kernel top-level __init__.py. - mcp>=1.0 added to [dev] extras so CI tests the real SDK path. - 9 tests: 8 stub-based unit tests + 1 real FastMCP in-process integration test using create_connected_server_and_client_session. - docs/integrations.md updated with real stdio + HTTP usage examples. - CHANGELOG.md Unreleased section updated. 
--- CHANGELOG.md | 3 + docs/integrations.md | 62 +++-- pyproject.toml | 1 + src/agent_kernel/__init__.py | 2 + src/agent_kernel/drivers/__init__.py | 3 +- src/agent_kernel/drivers/mcp.py | 179 ++++++++++++++ src/agent_kernel/drivers/mcp_support.py | 138 +++++++++++ tests/test_mcp_driver.py | 295 ++++++++++++++++++++++++ 8 files changed, 664 insertions(+), 19 deletions(-) create mode 100644 src/agent_kernel/drivers/mcp.py create mode 100644 src/agent_kernel/drivers/mcp_support.py create mode 100644 tests/test_mcp_driver.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cd4a99f..bc0f83e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Built-in `MCPDriver` with stdio and Streamable HTTP transports, tool auto-discovery, normalized MCP result handling, and optional dependency guardrails. + ## [0.4.0] - 2026-03-14 ### Added diff --git a/docs/integrations.md b/docs/integrations.md index 4a3533b..af4230f 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -2,34 +2,60 @@ ## MCP (Model Context Protocol) -To integrate with an MCP server, implement a custom driver that wraps the MCP client: +The built-in `MCPDriver` supports both local stdio servers and remote Streamable HTTP servers. 
-```python -from agent_kernel.drivers.base import Driver, ExecutionContext -from agent_kernel.models import RawResult +Install the optional dependency first: -class MCPDriver: - def __init__(self, mcp_client, driver_id: str = "mcp"): - self._client = mcp_client - self._driver_id = driver_id +```bash +pip install "weaver-kernel[mcp]" +``` - @property - def driver_id(self) -> str: - return self._driver_id +### Stdio transport + +```python +from agent_kernel import CapabilityRegistry, Kernel, StaticRouter +from agent_kernel.drivers.mcp import MCPDriver + +registry = CapabilityRegistry() +router = StaticRouter(fallback=[]) +kernel = Kernel(registry=registry, router=router) + +# Connect to a local MCP server process. +driver = MCPDriver.from_stdio( + command="python", + args=["-m", "my_mcp_server"], + server_name="local-tools", +) +kernel.register_driver(driver) + +# Discover tools and register them as capabilities. +capabilities = await driver.discover(namespace="local") +registry.register_many(capabilities) - async def execute(self, ctx: ExecutionContext) -> RawResult: - operation = ctx.args.get("operation", ctx.capability_id) - result = await self._client.call_tool(operation, ctx.args) - return RawResult(capability_id=ctx.capability_id, data=result) +# Route each discovered capability to this MCP driver. +for capability in capabilities: + router.add_route(capability.capability_id, [driver.driver_id]) ``` -Then register it: +### Streamable HTTP transport ```python -kernel.register_driver(MCPDriver(mcp_client)) -router.add_route("mcp.my_tool", ["mcp"]) +from agent_kernel.drivers.mcp import MCPDriver + +http_driver = MCPDriver.from_http( + url="https://example.com/mcp", + server_name="remote-tools", + max_retries=1, +) ``` +### Notes + +- `discover()` converts `tools/list` results into `Capability` objects. +- `execute()` calls `tools/call` and normalizes MCP content blocks for the firewall. +- MCP `isError` responses raise `DriverError` with the server-provided detail. 
+- If `mcp` is not installed, factory methods raise a helpful `ImportError`. + ## HTTPDriver The built-in `HTTPDriver` supports GET, POST, PUT, DELETE: diff --git a/pyproject.toml b/pyproject.toml index c302c07..39f8bf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dev = [ "ruff>=0.4", "mypy>=1.10", "httpx>=0.27", + "mcp>=1.0", ] mcp = ["mcp>=1.0"] otel = ["opentelemetry-api>=1.20"] diff --git a/src/agent_kernel/__init__.py b/src/agent_kernel/__init__.py index e9da214..35cfb78 100644 --- a/src/agent_kernel/__init__.py +++ b/src/agent_kernel/__init__.py @@ -37,6 +37,7 @@ from .drivers.base import Driver, ExecutionContext from .drivers.http import HTTPDriver +from .drivers.mcp import MCPDriver from .drivers.memory import InMemoryDriver, make_billing_driver from .enums import SafetyClass, SensitivityTag from .errors import ( @@ -129,6 +130,7 @@ "ExecutionContext", "InMemoryDriver", "HTTPDriver", + "MCPDriver", "make_billing_driver", # firewall "Firewall", diff --git a/src/agent_kernel/drivers/__init__.py b/src/agent_kernel/drivers/__init__.py index 6163c73..4d2de5c 100644 --- a/src/agent_kernel/drivers/__init__.py +++ b/src/agent_kernel/drivers/__init__.py @@ -2,6 +2,7 @@ from .base import Driver, ExecutionContext from .http import HTTPDriver +from .mcp import MCPDriver from .memory import InMemoryDriver -__all__ = ["Driver", "ExecutionContext", "HTTPDriver", "InMemoryDriver"] +__all__ = ["Driver", "ExecutionContext", "HTTPDriver", "MCPDriver", "InMemoryDriver"] diff --git a/src/agent_kernel/drivers/mcp.py b/src/agent_kernel/drivers/mcp.py new file mode 100644 index 0000000..e84b23c --- /dev/null +++ b/src/agent_kernel/drivers/mcp.py @@ -0,0 +1,179 @@ +"""MCP driver: execute capabilities against Model Context Protocol servers.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from ..enums import SafetyClass +from ..errors import DriverError +from ..models import Capability, 
ImplementationRef, RawResult +from .base import ExecutionContext +from .mcp_support import ( + SessionFactory, + build_http_session_factory, + build_stdio_session_factory, + call_tool, + extract_tool_specs, + normalize_call_result, +) + + +class MCPDriver: + """A driver that invokes capabilities via MCP tools/call.""" + + def __init__( + self, + *, + driver_id: str, + session_factory: SessionFactory, + server_name: str, + transport: str, + max_http_retries: int = 1, + ) -> None: + self._driver_id = driver_id + self._session_factory = session_factory + self._server_name = server_name + self._transport = transport + self._max_http_retries = max(max_http_retries, 0) + + @property + def driver_id(self) -> str: + """Unique identifier for this driver instance.""" + return self._driver_id + + @classmethod + def from_stdio( + cls, + command: str, + args: list[str] | None = None, + *, + server_name: str = "stdio", + ) -> MCPDriver: + """Create an MCP driver using stdio transport. + + Raises: + ImportError: If the optional ``mcp`` dependency is not installed. + """ + session_factory = build_stdio_session_factory(command=command, args=args or []) + return cls( + driver_id=f"mcp:{server_name}", + session_factory=session_factory, + server_name=server_name, + transport="stdio", + max_http_retries=0, + ) + + @classmethod + def from_http( + cls, + url: str, + *, + server_name: str = "http", + max_retries: int = 1, + ) -> MCPDriver: + """Create an MCP driver using Streamable HTTP transport. + + Raises: + ImportError: If the optional ``mcp`` dependency is not installed. 
+ """ + session_factory = build_http_session_factory(url=url) + return cls( + driver_id=f"mcp:{server_name}", + session_factory=session_factory, + server_name=server_name, + transport="http", + max_http_retries=max_retries, + ) + + async def discover( + self, + *, + namespace: str | None = None, + safety_class_map: dict[str, SafetyClass] | None = None, + ) -> list[Capability]: + """Discover MCP tools and convert them to capabilities.""" + tool_list = await self._run_with_retry( + operation_name="tools/list", + action=lambda session: session.list_tools(), + ) + + capabilities: list[Capability] = [] + for spec in extract_tool_specs(tool_list): + capability_id = f"{namespace}.{spec.name}" if namespace else spec.name + safety_class = ( + safety_class_map.get(spec.name, SafetyClass.READ) + if safety_class_map is not None + else SafetyClass.READ + ) + capabilities.append( + Capability( + capability_id=capability_id, + name=spec.name, + description=spec.description, + safety_class=safety_class, + tags=["mcp", self._server_name], + impl=ImplementationRef( + driver_id=self._driver_id, + operation=spec.name, + ), + ) + ) + return capabilities + + async def execute(self, ctx: ExecutionContext) -> RawResult: + """Execute an MCP tool call for the given capability context.""" + operation = str(ctx.args.get("operation", ctx.capability_id)) + params = {k: v for k, v in ctx.args.items() if k != "operation"} + + # Apply policy constraints as default arguments, without overriding explicit args. 
+ for key, value in ctx.constraints.items(): + params.setdefault(key, value) + + result = await self._run_with_retry( + operation_name=f"tools/call:{operation}", + action=lambda session: call_tool( + session, + operation=operation, + params=params, + ), + ) + + data = normalize_call_result( + result, + operation=operation, + driver_id=self._driver_id, + ) + return RawResult( + capability_id=ctx.capability_id, + data=data, + metadata={ + "driver_id": self._driver_id, + "transport": self._transport, + "operation": operation, + }, + ) + + async def _run_with_retry( + self, + *, + operation_name: str, + action: Callable[[Any], Awaitable[Any]], + ) -> Any: + attempts = 1 + self._max_http_retries if self._transport == "http" else 1 + last_exc: Exception | None = None + + for _attempt in range(attempts): + try: + async with self._session_factory() as session: + return await action(session) + except DriverError: + raise + except Exception as exc: + last_exc = exc + + reason = str(last_exc) if last_exc is not None else "unknown transport failure" + raise DriverError( + f"MCPDriver '{self._driver_id}' failed during {operation_name} over " + f"{self._transport}: {reason}" + ) from last_exc diff --git a/src/agent_kernel/drivers/mcp_support.py b/src/agent_kernel/drivers/mcp_support.py new file mode 100644 index 0000000..ba2d7bc --- /dev/null +++ b/src/agent_kernel/drivers/mcp_support.py @@ -0,0 +1,138 @@ +"""Internal helpers for MCP transport wiring and result normalization.""" + +from __future__ import annotations + +import importlib +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from dataclasses import dataclass +from typing import Any + +from ..errors import DriverError + +SessionFactory = Callable[[], AbstractAsyncContextManager[Any]] + + +@dataclass(slots=True) +class ToolSpec: + """Normalized MCP tool metadata for capability conversion.""" + + name: str + description: str + + +async def 
call_tool(session: Any, *, operation: str, params: dict[str, Any]) -> Any: + """Call an MCP tool via tools/call.""" + return await session.call_tool(operation, arguments=params) + + +def extract_tool_specs(tool_list_response: Any) -> list[ToolSpec]: + """Extract tool metadata from a tools/list response payload.""" + tools = getattr(tool_list_response, "tools", []) + if not isinstance(tools, list): + return [] + specs: list[ToolSpec] = [] + for tool in tools: + name = getattr(tool, "name", None) + if not isinstance(name, str) or not name: + continue + specs.append( + ToolSpec( + name=name, + description=str(getattr(tool, "description", "") or ""), + ) + ) + return specs + + +def normalize_call_result(result: Any, *, operation: str, driver_id: str) -> Any: + """Normalize an MCP CallToolResult into plain Python data.""" + is_error = bool(getattr(result, "isError", False)) + content = [_normalize_content_item(c) for c in (getattr(result, "content", None) or [])] + + if is_error: + detail = next( + (b["text"] for b in content if b.get("type") == "text" and b.get("text", "").strip()), + "MCP server returned isError=true", + ) + raise DriverError( + f"MCPDriver '{driver_id}' tool '{operation}' returned an error: {detail}" + ) + + structured: dict[str, Any] | None = getattr(result, "structuredContent", None) + if structured is not None and content: + return {"structured_content": structured, "content": content} + if structured is not None: + return structured + return content if content else {} + + +def import_optional(module_name: str) -> Any: + """Import optional MCP SDK module with a consistent guidance message.""" + try: + return importlib.import_module(module_name) + except ModuleNotFoundError as exc: + raise ImportError( + "MCP support requires the optional dependency 'mcp>=1.0'. 
" + "Install it with: pip install 'weaver-kernel[mcp]'" + ) from exc + + +def build_stdio_session_factory(*, command: str, args: list[str]) -> SessionFactory: + """Build a stdio-backed MCP session factory.""" + stdio_mod = import_optional("mcp.client.stdio") + session_mod = import_optional("mcp.client.session") + + stdio_client = stdio_mod.stdio_client + server_params_cls = stdio_mod.StdioServerParameters + session_cls = session_mod.ClientSession + + @asynccontextmanager + async def factory() -> AsyncIterator[Any]: + params = server_params_cls(command=command, args=args) + async with stdio_client(params) as streams: + read_stream, write_stream = streams + async with session_cls(read_stream, write_stream) as session: + await session.initialize() + yield session + + return factory + + +def build_http_session_factory(*, url: str) -> SessionFactory: + """Build a Streamable HTTP-backed MCP session factory.""" + streamable_mod = import_optional("mcp.client.streamable_http") + session_mod = import_optional("mcp.client.session") + + streamable_http_client = streamable_mod.streamable_http_client + session_cls = session_mod.ClientSession + + @asynccontextmanager + async def factory() -> AsyncIterator[Any]: + async with streamable_http_client(url) as streams: + read_stream, write_stream = streams + async with session_cls(read_stream, write_stream) as session: + await session.initialize() + yield session + + return factory + + +def _normalize_content_item(item: Any) -> dict[str, Any]: + item_type = str(getattr(item, "type", "")).lower() + if item_type == "text": + return {"type": "text", "text": str(getattr(item, "text", ""))} + if item_type == "image": + return { + "type": "image", + "data": getattr(item, "data", None), + "mime_type": getattr(item, "mimeType", None), + } + if item_type in {"resource", "resourcelink"}: + resource = getattr(item, "resource", item) + return { + "type": "resource", + "resource": resource.model_dump() if hasattr(resource, "model_dump") else 
resource, + } + # AudioContent or any future type - fall back to model_dump + return item.model_dump() if hasattr(item, "model_dump") else {"type": "value", "value": item} diff --git a/tests/test_mcp_driver.py b/tests/test_mcp_driver.py new file mode 100644 index 0000000..5488950 --- /dev/null +++ b/tests/test_mcp_driver.py @@ -0,0 +1,295 @@ +"""Tests for the built-in MCPDriver.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +from unittest.mock import patch + +import pytest +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool + +from agent_kernel import ( + CapabilityRegistry, + CapabilityRequest, + DriverError, + Kernel, + MCPDriver, + Principal, + SafetyClass, + StaticRouter, +) + + +class _FakeSession: + """Small async session stub for MCPDriver tests.""" + + def __init__( + self, + *, + tools: list[Tool] | None = None, + call_result: CallToolResult | None = None, + call_error: Exception | None = None, + ) -> None: + self._tools = tools or [] + self._call_result = call_result or CallToolResult(content=[]) + self._call_error = call_error + self.calls: list[tuple[str, dict[str, Any]]] = [] + + async def list_tools(self) -> ListToolsResult: + return ListToolsResult(tools=self._tools) + + async def call_tool( + self, + operation: str, + arguments: dict[str, Any], + ) -> CallToolResult: + self.calls.append((operation, arguments)) + if self._call_error is not None: + raise self._call_error + return self._call_result + + +def _reusable_factory(session: _FakeSession) -> Any: + @asynccontextmanager + async def _factory() -> AsyncIterator[_FakeSession]: + yield session + + return _factory + + +def _sequence_factory(sessions: list[_FakeSession]) -> Any: + @asynccontextmanager + async def _factory() -> AsyncIterator[_FakeSession]: + if not sessions: + raise RuntimeError("no fake sessions left") + yield sessions.pop(0) + + return _factory + + +def 
test_from_stdio_missing_dependency_raises_helpful_import_error() -> None: + with patch("agent_kernel.drivers.mcp_support.importlib.import_module") as import_module: + import_module.side_effect = ModuleNotFoundError("No module named 'mcp'") + with pytest.raises(ImportError, match=r"weaver-kernel\[mcp\]"): + MCPDriver.from_stdio("python", ["server.py"]) + + +def test_from_http_missing_dependency_raises_helpful_import_error() -> None: + with patch("agent_kernel.drivers.mcp_support.importlib.import_module") as import_module: + import_module.side_effect = ModuleNotFoundError("No module named 'mcp'") + with pytest.raises(ImportError, match=r"weaver-kernel\[mcp\]"): + MCPDriver.from_http("http://localhost:8080/mcp") + + +@pytest.mark.asyncio +async def test_discover_converts_tools_to_capabilities() -> None: + session = _FakeSession( + tools=[ + Tool(name="list_files", description="List files", inputSchema={}), + Tool(name="write_file", description="Write file", inputSchema={}), + ] + ) + driver = MCPDriver( + driver_id="mcp:local", + session_factory=_reusable_factory(session), + server_name="local", + transport="stdio", + ) + + capabilities = await driver.discover( + namespace="fs", safety_class_map={"write_file": SafetyClass.WRITE} + ) + + assert [cap.capability_id for cap in capabilities] == [ + "fs.list_files", + "fs.write_file", + ] + assert capabilities[0].safety_class == SafetyClass.READ + assert capabilities[1].safety_class == SafetyClass.WRITE + assert capabilities[0].impl is not None + assert capabilities[0].impl.driver_id == "mcp:local" + assert capabilities[0].impl.operation == "list_files" + + +@pytest.mark.asyncio +async def test_execute_calls_tool_and_applies_constraints_defaults() -> None: + session = _FakeSession( + call_result=CallToolResult(content=[TextContent(type="text", text="ok")]) + ) + driver = MCPDriver( + driver_id="mcp:local", + session_factory=_reusable_factory(session), + server_name="local", + transport="stdio", + ) + + from 
agent_kernel.drivers.base import ExecutionContext + + ctx = ExecutionContext( + capability_id="fs.list_files", + principal_id="u1", + args={"operation": "list_files", "path": "/tmp", "max_rows": 5}, + constraints={"max_rows": 2, "allowed_fields": ["name"]}, + ) + + result = await driver.execute(ctx) + + assert result.capability_id == "fs.list_files" + assert result.data == [{"type": "text", "text": "ok"}] + assert session.calls[0][0] == "list_files" + # Explicit args are preserved; missing constraints are merged in. + assert session.calls[0][1]["max_rows"] == 5 + assert session.calls[0][1]["allowed_fields"] == ["name"] + + +@pytest.mark.asyncio +async def test_execute_prefers_structured_content_when_available() -> None: + session = _FakeSession( + call_result=CallToolResult( + structuredContent={"total": 3}, + content=[TextContent(type="text", text="computed")], + ) + ) + driver = MCPDriver( + driver_id="mcp:local", + session_factory=_reusable_factory(session), + server_name="local", + transport="stdio", + ) + + from agent_kernel.drivers.base import ExecutionContext + + ctx = ExecutionContext(capability_id="math.sum", principal_id="u1") + result = await driver.execute(ctx) + + assert result.data == { + "structured_content": {"total": 3}, + "content": [{"type": "text", "text": "computed"}], + } + + +@pytest.mark.asyncio +async def test_execute_raises_driver_error_on_mcp_is_error() -> None: + session = _FakeSession( + call_result=CallToolResult( + isError=True, + content=[TextContent(type="text", text="permission denied")], + ) + ) + driver = MCPDriver( + driver_id="mcp:local", + session_factory=_reusable_factory(session), + server_name="local", + transport="stdio", + ) + + from agent_kernel.drivers.base import ExecutionContext + + ctx = ExecutionContext(capability_id="secrets.read", principal_id="u1") + with pytest.raises(DriverError, match="permission denied"): + await driver.execute(ctx) + + +@pytest.mark.asyncio +async def 
test_http_transport_retries_after_connection_drop() -> None: + first = _FakeSession(call_error=RuntimeError("connection dropped")) + second = _FakeSession( + call_result=CallToolResult(content=[TextContent(type="text", text="ok")]) + ) + + driver = MCPDriver( + driver_id="mcp:http", + session_factory=_sequence_factory([first, second]), + server_name="remote", + transport="http", + max_http_retries=1, + ) + + from agent_kernel.drivers.base import ExecutionContext + + ctx = ExecutionContext(capability_id="echo", principal_id="u1") + result = await driver.execute(ctx) + + assert result.data == [{"type": "text", "text": "ok"}] + assert len(first.calls) == 1 + assert len(second.calls) == 1 + + +@pytest.mark.asyncio +async def test_kernel_pipeline_with_discover_register_grant_invoke() -> None: + session = _FakeSession( + tools=[Tool(name="math.sum", description="Sum two values", inputSchema={})], + call_result=CallToolResult(structuredContent={"total": 3}, content=[]), + ) + driver = MCPDriver( + driver_id="mcp:demo", + session_factory=_reusable_factory(session), + server_name="demo", + transport="stdio", + ) + + capabilities = await driver.discover() + registry = CapabilityRegistry() + registry.register_many(capabilities) + + router = StaticRouter(routes={"math.sum": ["mcp:demo"]}, fallback=[]) + kernel = Kernel(registry=registry, router=router) + kernel.register_driver(driver) + + principal = Principal(principal_id="u1", roles=["reader"]) + request = CapabilityRequest(capability_id="math.sum", goal="sum numbers") + + token = kernel.get_token(request, principal, justification="") + frame = await kernel.invoke( + token, + principal=principal, + args={"operation": "math.sum", "a": 1, "b": 2}, + ) + + assert frame.response_mode == "summary" + assert any("total" in fact.lower() for fact in frame.facts) + + +@pytest.mark.asyncio +async def test_real_fastmcp_in_process_discover_and_execute() -> None: + """Full discover→execute cycle driven by a real FastMCP server 
in-process.""" + from mcp.client.session import ClientSession + from mcp.server.fastmcp import FastMCP + from mcp.shared.memory import create_connected_server_and_client_session + + mcp_srv = FastMCP("math") + + @mcp_srv.tool() + def add(a: int, b: int) -> int: + """Add two integers.""" + return a + b + + @asynccontextmanager + async def in_memory_factory() -> AsyncIterator[ClientSession]: + async with create_connected_server_and_client_session(mcp_srv) as session: + yield session + + driver = MCPDriver( + driver_id="mcp:math", + session_factory=in_memory_factory, + server_name="math", + transport="stdio", + ) + + capabilities = await driver.discover(namespace="math") + assert any(cap.capability_id == "math.add" for cap in capabilities) + add_cap = next(c for c in capabilities if c.capability_id == "math.add") + assert add_cap.impl is not None + assert add_cap.impl.operation == "add" + + from agent_kernel.drivers.base import ExecutionContext + + ctx = ExecutionContext( + capability_id="math.add", + principal_id="u1", + args={"operation": "add", "a": 3, "b": 4}, + ) + result = await driver.execute(ctx) + assert result.data is not None From a03f6310a4186d68d7d85d3c935306f6c96e2949 Mon Sep 17 00:00:00 2001 From: Diogo Andre Santos Date: Sun, 12 Apr 2026 11:58:53 +0100 Subject: [PATCH 2/3] fix: address review comments on MCPDriver - _run_with_retry: add inline comment explaining broad exception catch is intentional; MCP tool errors arrive as isError=True responses, not Python exceptions, so exceptions at session-factory level are transport failures (connection refused, EOF, timeout) - normalize_call_result: return content (always list) instead of {} on empty content path; eliminates type inconsistency and gives a more informative firewall summary ('List of 0 items' vs 'Keys: ') - docs/integrations.md: wrap stdio example in async def main() / asyncio.run(main()) so the snippet is copy-paste runnable, consistent with all three repo examples --- docs/integrations.md | 
39 +++++++++++++++---------- src/agent_kernel/drivers/mcp.py | 5 ++++ src/agent_kernel/drivers/mcp_support.py | 2 +- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/docs/integrations.md b/docs/integrations.md index af4230f..73809e7 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -13,28 +13,35 @@ pip install "weaver-kernel[mcp]" ### Stdio transport ```python +import asyncio + from agent_kernel import CapabilityRegistry, Kernel, StaticRouter from agent_kernel.drivers.mcp import MCPDriver -registry = CapabilityRegistry() -router = StaticRouter(fallback=[]) -kernel = Kernel(registry=registry, router=router) -# Connect to a local MCP server process. -driver = MCPDriver.from_stdio( - command="python", - args=["-m", "my_mcp_server"], - server_name="local-tools", -) -kernel.register_driver(driver) +async def main() -> None: + registry = CapabilityRegistry() + router = StaticRouter(fallback=[]) + kernel = Kernel(registry=registry, router=router) + + # Connect to a local MCP server process. + driver = MCPDriver.from_stdio( + command="python", + args=["-m", "my_mcp_server"], + server_name="local-tools", + ) + kernel.register_driver(driver) + + # Discover tools and register them as capabilities. + capabilities = await driver.discover(namespace="local") + registry.register_many(capabilities) + + # Route each discovered capability to this MCP driver. + for capability in capabilities: + router.add_route(capability.capability_id, [driver.driver_id]) -# Discover tools and register them as capabilities. -capabilities = await driver.discover(namespace="local") -registry.register_many(capabilities) -# Route each discovered capability to this MCP driver. 
-for capability in capabilities: - router.add_route(capability.capability_id, [driver.driver_id]) +asyncio.run(main()) ``` ### Streamable HTTP transport diff --git a/src/agent_kernel/drivers/mcp.py b/src/agent_kernel/drivers/mcp.py index e84b23c..0af6a29 100644 --- a/src/agent_kernel/drivers/mcp.py +++ b/src/agent_kernel/drivers/mcp.py @@ -170,6 +170,11 @@ async def _run_with_retry( except DriverError: raise except Exception as exc: + # Broad catch is intentional: exceptions at this level are + # session/transport failures (connection refused, EOF, timeout). + # MCP tool-level application errors are returned as isError=True + # responses and converted to DriverError before reaching this + # handler — they never appear as Python exceptions here. last_exc = exc reason = str(last_exc) if last_exc is not None else "unknown transport failure" diff --git a/src/agent_kernel/drivers/mcp_support.py b/src/agent_kernel/drivers/mcp_support.py index ba2d7bc..4bce876 100644 --- a/src/agent_kernel/drivers/mcp_support.py +++ b/src/agent_kernel/drivers/mcp_support.py @@ -64,7 +64,7 @@ def normalize_call_result(result: Any, *, operation: str, driver_id: str) -> Any return {"structured_content": structured, "content": content} if structured is not None: return structured - return content if content else {} + return content def import_optional(module_name: str) -> Any: From 9421a2b30cd47fb702aba00dff078be84c1b9571 Mon Sep 17 00:00:00 2001 From: Diogo Andre Santos Date: Sun, 12 Apr 2026 13:15:38 +0100 Subject: [PATCH 3/3] fix: harden MCPDriver based on human review + MCP 1.26 SDK pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Blocker fixes: - tests/test_mcp_driver.py: strengthen integration test assertion for real add(3,4) result; assert structured_content.result==7 and content[0].text=='7' rather than bare 'is not None' - mcp.py/_run_with_retry: handle McpError (protocol-level rejection) immediately as DriverError; McpError is not 
retryable — the server processed and rejected the request Major fixes: - mcp.py/discover(): paginate tools/list via _fetch_all_tools; loop on nextCursor until exhausted to avoid silent capability truncation on large MCP servers - mcp_support.py/ToolSpec + discover(): forward ToolAnnotations hints (readOnlyHint, destructiveHint, idempotentHint) to ToolSpec; derive SafetyClass from them via _infer_safety_class(); safety_class_map still overrides. Eliminates always-READ misclassification of destructive tools - mcp_support.py/call_tool(): forward read_timeout_seconds from constraints to ClientSession.call_tool(); prevents indefinite hangs over HTTP - mcp.py/_run_with_retry: document at-least-once delivery semantics for HTTP transport; advise callers to set max_retries=0 for WRITE/DESTRUCTIVE - pyproject.toml: tighten mcp lower bound to >=1.6 in both [mcp] extra and [dev] extras; ToolAnnotations, outputSchema, and nextCursor require this release line Minor fixes: - mcp_support.py/ToolSpec: forward Tool.outputSchema for downstream use by firewall budget/redaction rules - docs/integrations.md: expand HTTP section with full async def main() discover() + register + route example; add at-least-once warning inline --- docs/integrations.md | 37 ++++++++++-- pyproject.toml | 4 +- src/agent_kernel/drivers/mcp.py | 76 +++++++++++++++++++++---- src/agent_kernel/drivers/mcp_support.py | 30 +++++++--- tests/test_mcp_driver.py | 7 ++- 5 files changed, 126 insertions(+), 28 deletions(-) diff --git a/docs/integrations.md b/docs/integrations.md index 73809e7..37d681e 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -47,13 +47,40 @@ asyncio.run(main()) ### Streamable HTTP transport ```python +import asyncio + +from agent_kernel import CapabilityRegistry, Kernel, StaticRouter from agent_kernel.drivers.mcp import MCPDriver -http_driver = MCPDriver.from_http( - url="https://example.com/mcp", - server_name="remote-tools", - max_retries=1, -) + +async def main() -> None: + 
registry = CapabilityRegistry() + router = StaticRouter(fallback=[]) + kernel = Kernel(registry=registry, router=router) + + # Connect to a remote Streamable HTTP MCP server. + # Note: max_retries > 0 creates at-least-once delivery semantics for + # tools/call — if a connection drops after the server processes the + # request but before the response arrives, the call will be repeated. + # Ensure target tools are idempotent, or set max_retries=0 for + # WRITE/DESTRUCTIVE capabilities. + driver = MCPDriver.from_http( + url="https://example.com/mcp", + server_name="remote-tools", + max_retries=1, + ) + kernel.register_driver(driver) + + # Discover tools and register them as capabilities. + capabilities = await driver.discover(namespace="remote") + registry.register_many(capabilities) + + # Route each discovered capability to this MCP driver. + for capability in capabilities: + router.add_route(capability.capability_id, [driver.driver_id]) + + +asyncio.run(main()) ``` ### Notes diff --git a/pyproject.toml b/pyproject.toml index 39f8bf9..53ec1e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ dev = [ "ruff>=0.4", "mypy>=1.10", "httpx>=0.27", - "mcp>=1.0", + "mcp>=1.6", ] -mcp = ["mcp>=1.0"] +mcp = ["mcp>=1.6"] otel = ["opentelemetry-api>=1.20"] [tool.hatch.build.targets.wheel] diff --git a/src/agent_kernel/drivers/mcp.py b/src/agent_kernel/drivers/mcp.py index 0af6a29..caba34e 100644 --- a/src/agent_kernel/drivers/mcp.py +++ b/src/agent_kernel/drivers/mcp.py @@ -11,6 +11,7 @@ from .base import ExecutionContext from .mcp_support import ( SessionFactory, + ToolSpec, build_http_session_factory, build_stdio_session_factory, call_tool, @@ -18,6 +19,27 @@ normalize_call_result, ) +# Lazy import of McpError — only available when the mcp optional dep is installed. +# If mcp is absent, factory methods raise ImportError before any session is created, +# so _McpError will never be None on a live driver instance. 
+try: + from mcp.shared.exceptions import McpError as _McpError +except ImportError: # pragma: no cover + _McpError = None # type: ignore[assignment,misc] + + +def _infer_safety_class(spec: ToolSpec) -> SafetyClass: + """Infer a SafetyClass from MCP ToolAnnotations hints. + + Uses a conservative default of READ when annotations are absent. + The caller's safety_class_map takes precedence over the inferred value. + """ + if spec.destructive_hint: + return SafetyClass.DESTRUCTIVE + if spec.read_only_hint: + return SafetyClass.READ + return SafetyClass.READ + class MCPDriver: """A driver that invokes capabilities via MCP tools/call.""" @@ -92,19 +114,20 @@ async def discover( namespace: str | None = None, safety_class_map: dict[str, SafetyClass] | None = None, ) -> list[Capability]: - """Discover MCP tools and convert them to capabilities.""" - tool_list = await self._run_with_retry( + """Discover MCP tools across all pages and convert them to capabilities.""" + tools = await self._run_with_retry( operation_name="tools/list", - action=lambda session: session.list_tools(), + action=self._fetch_all_tools, ) capabilities: list[Capability] = [] - for spec in extract_tool_specs(tool_list): + for spec in extract_tool_specs(tools): capability_id = f"{namespace}.{spec.name}" if namespace else spec.name + inferred = _infer_safety_class(spec) safety_class = ( - safety_class_map.get(spec.name, SafetyClass.READ) + safety_class_map.get(spec.name, inferred) if safety_class_map is not None - else SafetyClass.READ + else inferred ) capabilities.append( Capability( @@ -121,14 +144,34 @@ async def discover( ) return capabilities + async def _fetch_all_tools(self, session: Any) -> list[Any]: + """Paginate tools/list to exhaustion and return a flat list of Tool objects.""" + all_tools: list[Any] = [] + cursor: str | None = None + while True: + result = await session.list_tools(cursor=cursor) + all_tools.extend(getattr(result, "tools", []) or []) + cursor = getattr(result, "nextCursor", 
None) + if not cursor: + break + return all_tools + async def execute(self, ctx: ExecutionContext) -> RawResult: """Execute an MCP tool call for the given capability context.""" operation = str(ctx.args.get("operation", ctx.capability_id)) params = {k: v for k, v in ctx.args.items() if k != "operation"} # Apply policy constraints as default arguments, without overriding explicit args. + # read_timeout_seconds is an SDK control parameter — applied to the session call + # directly rather than forwarded to the tool as an argument. + read_timeout_seconds_raw = ctx.constraints.get("read_timeout_seconds") for key, value in ctx.constraints.items(): - params.setdefault(key, value) + if key != "read_timeout_seconds": + params.setdefault(key, value) + + read_timeout_seconds: float | None = ( + float(read_timeout_seconds_raw) if read_timeout_seconds_raw is not None else None + ) result = await self._run_with_retry( operation_name=f"tools/call:{operation}", @@ -136,6 +179,7 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: session, operation=operation, params=params, + read_timeout_seconds=read_timeout_seconds, ), ) @@ -170,11 +214,19 @@ async def _run_with_retry( except DriverError: raise except Exception as exc: - # Broad catch is intentional: exceptions at this level are - # session/transport failures (connection refused, EOF, timeout). - # MCP tool-level application errors are returned as isError=True - # responses and converted to DriverError before reaching this - # handler — they never appear as Python exceptions here. + # McpError is a protocol-level rejection (tool not found, auth + # failure, invalid params) — the server processed and rejected the + # request. It is not retryable; surface it immediately as DriverError. 
+ if _McpError is not None and isinstance(exc, _McpError): + raise DriverError( + f"MCPDriver '{self._driver_id}' received a protocol error " + f"during {operation_name}: {exc}" + ) from exc + # All other exceptions are session/transport failures (connection + # refused, EOF, timeout) and are retryable for HTTP transport. + # Note: HTTP retries create at-least-once delivery semantics for + # tools/call. Callers using WRITE/DESTRUCTIVE capabilities over HTTP + # should ensure the target tool is idempotent, or set max_retries=0. last_exc = exc reason = str(last_exc) if last_exc is not None else "unknown transport failure" diff --git a/src/agent_kernel/drivers/mcp_support.py b/src/agent_kernel/drivers/mcp_support.py index 4bce876..048ec5a 100644 --- a/src/agent_kernel/drivers/mcp_support.py +++ b/src/agent_kernel/drivers/mcp_support.py @@ -6,6 +6,7 @@ from collections.abc import AsyncIterator, Callable from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass +from datetime import timedelta from typing import Any from ..errors import DriverError @@ -19,16 +20,26 @@ class ToolSpec: name: str description: str - - -async def call_tool(session: Any, *, operation: str, params: dict[str, Any]) -> Any: + read_only_hint: bool = False + destructive_hint: bool = False + idempotent_hint: bool = False + output_schema: dict[str, Any] | None = None + + +async def call_tool( + session: Any, + *, + operation: str, + params: dict[str, Any], + read_timeout_seconds: float | None = None, +) -> Any: """Call an MCP tool via tools/call.""" - return await session.call_tool(operation, arguments=params) + timeout = timedelta(seconds=read_timeout_seconds) if read_timeout_seconds is not None else None + return await session.call_tool(operation, arguments=params, read_timeout_seconds=timeout) -def extract_tool_specs(tool_list_response: Any) -> list[ToolSpec]: - """Extract tool metadata from a tools/list response payload.""" - tools = 
getattr(tool_list_response, "tools", []) +def extract_tool_specs(tools: list[Any]) -> list[ToolSpec]: + """Extract tool metadata from a flat list of MCP Tool objects.""" if not isinstance(tools, list): return [] specs: list[ToolSpec] = [] @@ -36,10 +47,15 @@ def extract_tool_specs(tool_list_response: Any) -> list[ToolSpec]: name = getattr(tool, "name", None) if not isinstance(name, str) or not name: continue + ann = getattr(tool, "annotations", None) specs.append( ToolSpec( name=name, description=str(getattr(tool, "description", "") or ""), + read_only_hint=bool(getattr(ann, "readOnlyHint", False)), + destructive_hint=bool(getattr(ann, "destructiveHint", False)), + idempotent_hint=bool(getattr(ann, "idempotentHint", False)), + output_schema=getattr(tool, "outputSchema", None), ) ) return specs diff --git a/tests/test_mcp_driver.py b/tests/test_mcp_driver.py index 5488950..81d1a02 100644 --- a/tests/test_mcp_driver.py +++ b/tests/test_mcp_driver.py @@ -37,13 +37,14 @@ def __init__( self._call_error = call_error self.calls: list[tuple[str, dict[str, Any]]] = [] - async def list_tools(self) -> ListToolsResult: + async def list_tools(self, cursor: str | None = None) -> ListToolsResult: return ListToolsResult(tools=self._tools) async def call_tool( self, operation: str, arguments: dict[str, Any], + read_timeout_seconds: Any = None, ) -> CallToolResult: self.calls.append((operation, arguments)) if self._call_error is not None: @@ -292,4 +293,6 @@ async def in_memory_factory() -> AsyncIterator[ClientSession]: args={"operation": "add", "a": 3, "b": 4}, ) result = await driver.execute(ctx) - assert result.data is not None + assert isinstance(result.data, dict) + assert result.data["structured_content"]["result"] == 7 + assert result.data["content"][0]["text"] == "7"