diff --git a/AGENTS.md b/AGENTS.md index 45cfb9a..03622ea 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,7 +52,7 @@ Use these terms consistently. Never substitute synonyms: - Error messages are part of the contract — tests must assert both exception type and message. - Keep modules ≤ 300 lines. Split if needed. - No randomness in matching, routing, or summarization. Deterministic outputs always. -- No new dependencies without justification. The dep list is intentionally minimal (`httpx` only). +- No new dependencies without justification. The dep list is intentionally minimal (`httpx`, `pydantic`). ## Security rules diff --git a/CHANGELOG.md b/CHANGELOG.md index 700df3c..24d3a90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- LLM tool-format adapters and middleware (`agent_kernel.adapters`): `OpenAIMiddleware` (OpenAI + Responses API + Chat Completions, auto-detected on input) and `AnthropicMiddleware` (Anthropic + Messages with `cache_control` support). Both translate `Capability` objects to vendor tool + schemas, route tool calls through the full kernel pipeline (grant → invoke → firewall → trace), + and surface kernel errors (`PolicyDenied`, `CapabilityNotFound`, `DriverError`) as tool-result + errors so the LLM can react. Pre/post hooks (`intercept_tool_call`, `intercept_tool_result`, + sync or async) support logging, metrics, approval gates, and per-call justification injection. + Zero runtime dependency on the `openai` / `anthropic` SDK packages. (#55, #50, #40) +- New `Capability` fields for LLM adapters: `parameters_model: type[pydantic.BaseModel] | None` + (input schema source + validation), `parameters_schema: dict | None` (raw JSON Schema escape + hatch), and `tool_hints: ToolHints | None` (vendor hints — Anthropic `cache_control`, OpenAI + `strict` mode). All default to ``None``; existing capabilities and tests are unaffected. 
+- New `ToolHints` dataclass and `OpenAIMiddleware` / `AnthropicMiddleware` top-level exports. +- New `AdapterParseError(AgentKernelError)` exception raised by adapter parse / validation + helpers (`tool_call_to_request`, `tool_use_to_request`, `make_namespace_safe_name`) instead + of bare `ValueError`. Satisfies `AGENTS.md`'s "no bare ValueError to callers" rule and + gives consumers a stable adapter-specific exception type. Also catches capability IDs that + contain the reserved OpenAI namespace separator `__` (which would otherwise produce + colliding tool names). +- `Kernel.list_capabilities()` convenience accessor returning every registered capability in + registration order. Used by the new adapters but generally useful for tooling that needs to + enumerate the registry without keyword search. - Declarative policy engine (`DeclarativePolicyEngine`) that loads rules from YAML or TOML files. Rules are evaluated top-down with first-match-wins semantics; supports `safety_class`, `sensitivity`, `roles`, `attributes`, and `min_justification` match conditions. (#42) @@ -28,6 +50,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Example policy files in `examples/policies/` (YAML and TOML formats). ### Changed +- Runtime dependencies now include `pydantic>=2` in addition to `httpx`. Pydantic is used by the new + `agent_kernel.adapters` package for JSON-Schema generation and argument validation when a + `Capability` declares a `parameters_model`. Existing kernel behavior is unchanged; pydantic is not + imported at module load by anything outside the adapters. - `PolicyEngine` protocol no longer requires `explain()`. Engines that need to support `Kernel.explain_denial()` should implement the new `ExplainingPolicyEngine` protocol. Built-in engines satisfy both. This avoids a breaking typing change for downstream implementers. 
diff --git a/docs/architecture.md b/docs/architecture.md index 5a7a6af..b865678 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -87,3 +87,24 @@ Stores full results by opaque handle ID with TTL. `expand()` supports pagination ### TraceStore Records every `ActionTrace`. `explain(action_id)` returns the full audit record. + +### Adapters (`agent_kernel.adapters`) +Vendor-specific tool-format adapters that translate between `Capability` objects +and the tool shapes used by LLM provider APIs: + +- **`OpenAIMiddleware`** — emits OpenAI tool definitions (Responses API or Chat + Completions shape), parses `response.output` / `message.tool_calls`, and + returns `function_call_output` / tool-result messages. Dotted capability IDs + map to `namespace__function` (OpenAI tool names cannot contain `.`). +- **`AnthropicMiddleware`** — emits Anthropic tool definitions with optional + `cache_control` blocks, parses `tool_use` content blocks, and returns + `tool_result` content blocks. Dotted capability IDs are preserved as-is. + +Both classes share `BaseToolMiddleware`, which owns hook registration +(`intercept_tool_call`, `intercept_tool_result`), pre/post dispatch (sync or +async), and conversion of kernel exceptions (`PolicyDenied`, +`CapabilityNotFound`, `DriverError`) into tool-result errors the LLM can react +to. Input arguments are validated against `Capability.parameters_model` +(pydantic) when present. **Zero runtime dependency** on the `openai` / +`anthropic` SDK packages. See [`docs/integrations.md`](integrations.md) for +usage examples. 
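As an illustration of the dotted-ID mapping described for `OpenAIMiddleware`, the following is a minimal standalone sketch of the round trip. It follows the behaviour this change describes for `make_namespace_safe_name` and `restore_namespace`, but it is not the package's code: `ReservedSeparatorError` is a hypothetical stand-in for `AdapterParseError`, and the names `to_tool_name` and `to_capability_id` are assumptions made for the example.

```python
SEP = "__"  # separator reserved for the dotted-ID mapping


class ReservedSeparatorError(ValueError):
    """Hypothetical stand-in for the adapter's AdapterParseError."""


def to_tool_name(capability_id: str) -> str:
    """Map a dotted capability ID to a tool name that contains no dots."""
    if SEP in capability_id:
        # "a__b" and "a.b" would both become "a__b", so refuse up front
        # instead of emitting two tools with the same name.
        raise ReservedSeparatorError(
            f"capability ID {capability_id!r} contains reserved separator {SEP!r}"
        )
    return capability_id.replace(".", SEP)


def to_capability_id(tool_name: str) -> str:
    """Inverse mapping: restore the dotted capability ID from a tool name."""
    return tool_name.replace(SEP, ".")
```

With these definitions, `to_tool_name("billing.list_invoices")` gives `billing__list_invoices` and `to_capability_id` reverses it. One edge this sketch does not handle: a segment that ends in a single underscore before a dot (`a_.b`) becomes `a___b`, which restores to `a._b` rather than the original ID.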
diff --git a/docs/integrations.md b/docs/integrations.md index 37d681e..d8e4a18 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -128,3 +128,202 @@ When mapping MCP tools to capabilities, prefer task-shaped names: | `write_file` | `fs.write_file` | WRITE | | `delete_file` | `fs.delete_file` | DESTRUCTIVE | | `execute_code` | `sandbox.run_code` | DESTRUCTIVE | + +## LLM tool-format adapters + +`agent_kernel.adapters` converts `Capability` objects into the tool shapes +expected by OpenAI and Anthropic, and routes the matching tool-call objects +back through the kernel pipeline (grant → invoke → firewall → trace). The +adapters are pure dict transforms — there is **no runtime dependency** on the +`openai` or `anthropic` SDK packages. + +### Input schemas with pydantic + +Capabilities advertise their input schema via two optional fields on +`Capability`: + +- `parameters_model: type[pydantic.BaseModel] | None` — pydantic model. The + adapter calls `.model_json_schema()` and validates tool-call arguments + against the model before invocation. +- `parameters_schema: dict | None` — raw JSON Schema, used verbatim. No + argument validation is performed (use `parameters_model` for that). + +`Capability.allowed_fields` is an **output redaction** control consumed by the +firewall — it is *not* used as an input schema source. + +```python +from pydantic import BaseModel, Field + +from agent_kernel import Capability, SafetyClass + + +class ListInvoicesArgs(BaseModel): + customer_id: str + limit: int = Field(default=10, ge=1, le=100) + + +list_invoices = Capability( + capability_id="billing.list_invoices", + name="List Invoices", + description="List invoices for a customer", + safety_class=SafetyClass.READ, + parameters_model=ListInvoicesArgs, +) +``` + +### OpenAI middleware + +```python +import asyncio + +from agent_kernel import Kernel, OpenAIMiddleware, Principal + + +async def main() -> None: + kernel = Kernel(registry=registry, ...) 
+ principal = Principal(principal_id="agent-1", roles=["reader"]) + mw = OpenAIMiddleware(kernel, principal) + + tools = mw.get_tools() # → list[dict] for OpenAI SDK + # response = await openai_client.responses.create(model=..., tools=tools, ...) + # outputs = await mw.handle_tool_calls(response.output) + # → list of {"type": "function_call_output", "call_id", "output"} dicts. + + +asyncio.run(main()) +``` + +The default output shape is **OpenAI Responses API** +(`function_call_output`). Use `format="chat_completions"` to emit nested +`{"type": "function", "function": {...}}` tool definitions and +`{"role": "tool", ...}` result messages instead: + +```python +mw = OpenAIMiddleware(kernel, principal, format="chat_completions") +``` + +`handle_tool_calls` auto-detects the input shape per call regardless of the +configured output format, so you can pass either Responses-API +`response.output` items or Chat-Completions `message.tool_calls` items. + +#### Namespace mapping + +OpenAI tool names cannot contain `.`, so dotted capability IDs are mapped to +double-underscore form on the way out and restored on the way back: + +| Capability ID | OpenAI tool name | +|---|---| +| `billing.list_invoices` | `billing__list_invoices` | +| `billing.invoices.list` | `billing__invoices__list` | + +Capability IDs that already contain `__` cannot be round-tripped unambiguously +(`a__b` and `a.b` would both produce the OpenAI tool name `a__b`). The adapter +rejects them at tool-emit time with an `AdapterParseError` rather than +silently emitting colliding tools. + +#### Strict mode + +Set `Capability.tool_hints = ToolHints(strict=True)` to emit the tool +definition with OpenAI's `strict: true` flag. The adapter normalises the +schema so every property is required and `additionalProperties` is `false` +at every level. If normalisation fails (e.g. a schema feature OpenAI strict +mode does not accept) the adapter falls back to non-strict and emits a +warning. 
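To make the strict-mode normalisation concrete, here is a small standalone sketch of the recursive walk described above, applied to a hand-written schema. It is an illustration under assumptions rather than the adapter's implementation: the function name `normalize_strict` and the sample schema are invented for the example, and it omits the fallback-with-warning path.

```python
from typing import Any


def normalize_strict(node: Any) -> Any:
    """Return a copy of a JSON Schema node adjusted for strict mode."""
    if isinstance(node, dict):
        # Recurse first so nested object schemas get the same treatment.
        out = {key: normalize_strict(value) for key, value in node.items()}
        if out.get("type") == "object":
            props = out.get("properties")
            if isinstance(props, dict):
                # Every declared property becomes required.
                out["required"] = list(props)
            out["additionalProperties"] = False
        return out
    if isinstance(node, list):
        return [normalize_strict(item) for item in node]
    return node  # scalars are returned unchanged


# Sample schema (assumed for illustration): "limit" has a default and is
# therefore not listed in "required" before normalisation.
schema = {
    "type": "object",
    "properties": {
        "customer_id": {"type": "string"},
        "limit": {"type": "integer", "default": 10},
        "filters": {
            "type": "object",
            "properties": {"status": {"type": "string"}},
        },
    },
    "required": ["customer_id"],
}

strict_schema = normalize_strict(schema)
```

After the call, `strict_schema["required"]` lists all three properties, the nested `filters` object is closed with `additionalProperties: false`, and the input `schema` is left untouched because the walk builds new dicts and lists at every level.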
+ +**Strict mode caveats** + +OpenAI strict mode requires every property be listed in `required`. The +adapter's normaliser enforces this unconditionally. That means pydantic +fields with non-`None` defaults — which pydantic itself emits as +*not* required — will be forced into `required` after normalisation. The +LLM is then expected to always include the field even though pydantic would +fall back to the default if it were omitted. + +To express a truly-optional field under strict mode, use the `Optional[T]` +pattern (with `None` as the default): + +```python +class ListInvoicesArgs(BaseModel): + customer_id: str # required, str + limit: int = 10 # forced into required by strict mode + cursor: str | None = None # required + nullable (LLM can pass null) +``` + +Pydantic emits `Optional[str] = None` (or `str | None = None`) as +`{"anyOf": [{"type": "string"}, {"type": "null"}]}`. OpenAI strict mode +accepts `null` as a valid value for such fields, so the LLM can effectively +"omit" them by passing `null`. + +### Anthropic middleware + +```python +import asyncio + +from agent_kernel import AnthropicMiddleware, Kernel, Principal + + +async def main() -> None: + kernel = Kernel(registry=registry, ...) + principal = Principal(principal_id="agent-1", roles=["reader"]) + mw = AnthropicMiddleware(kernel, principal) + + tools = mw.get_tools() # → list[dict] for Anthropic SDK + # message = await anthropic_client.messages.create(model=..., tools=tools, ...) + # tool_results = await mw.handle_tool_uses(message.content) + # → list of {"type": "tool_result", "tool_use_id", "content"} blocks. + + +asyncio.run(main()) +``` + +#### Prompt cache control + +Set `Capability.tool_hints = ToolHints(cache_control={"type": "ephemeral"})` +to attach Anthropic's prompt-cache control block to that capability's tool +definition. 
To apply a default to every tool that does not specify its own, +pass `default_cache_control` to the middleware: + +```python +mw = AnthropicMiddleware( + kernel, + principal, + default_cache_control={"type": "ephemeral"}, +) +``` + +### Hooks (pre/post invocation) + +Both middlewares accept synchronous or asynchronous callbacks via +`intercept_tool_call(callback)` and `intercept_tool_result(callback)`. Hooks +fire in registration order. Pre-hooks receive a mutable `ToolCallEvent` +(useful for logging, metrics, approval gates, injecting `justification` for +WRITE/DESTRUCTIVE calls); post-hooks receive a `ToolResultEvent` carrying +either the kernel `Frame` or an error string. + +```python +async def audit(event): + log.info("tool_call", capability=event.capability_id, principal=event.principal_id) + +def gate(event): + if event.capability_id.startswith("billing.delete"): + event.aborted = True + event.abort_reason = "deletions require manual approval" + +mw.intercept_tool_call(audit) +mw.intercept_tool_call(gate) +``` + +Setting `event.aborted = True` skips kernel invocation and produces a +tool-result error block containing `event.abort_reason`. Setting +`event.justification = "..."` lets a hook supply the per-call justification +the policy engine requires for WRITE/DESTRUCTIVE capabilities. Per-call +overrides can also be threaded through arguments as `_justification` (the +adapter pops it before passing args to the driver). + +### Errors are tool results, not exceptions + +`PolicyDenied`, `CapabilityNotFound`, `DriverError`, argument-validation +failures, and hook abort signals are all returned to the LLM as a tool result +with `error: true` (Anthropic also sets `is_error: true`). Raised exceptions +would crash the surrounding agent loop; the LLM cannot react to an +exception. 
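The error-as-result idea can be sketched in a few lines. The example below is a simplified, self-contained illustration of catching a failure and returning an Anthropic-style `tool_result` block with `is_error: true`. The names `run_tool`, `error_payload`, and `anthropic_tool_result` are hypothetical, and the local `PolicyDenied` class is only a stand-in for the kernel exception of the same name.

```python
import json
from typing import Any, Callable


class PolicyDenied(Exception):
    """Local stand-in for the kernel's PolicyDenied exception."""


def error_payload(capability_id: str, message: str) -> dict[str, Any]:
    """Error body with the fields the docs describe (error flag, id, message)."""
    return {"error": True, "capability_id": capability_id, "message": message}


def anthropic_tool_result(
    payload: dict[str, Any], *, tool_use_id: str, is_error: bool = False
) -> dict[str, Any]:
    """Wrap a payload in a tool_result block with one JSON text item."""
    body = json.dumps(payload, ensure_ascii=False, sort_keys=True)
    block: dict[str, Any] = {
        "type": "tool_result",
        "tool_use_id": tool_use_id,
        "content": [{"type": "text", "text": body}],
    }
    if is_error:
        block["is_error"] = True
    return block


def run_tool(
    tool_use_id: str, capability_id: str, call: Callable[[], dict[str, Any]]
) -> dict[str, Any]:
    """Run `call`; turn a policy denial into an error block instead of raising."""
    try:
        payload = call()
    except PolicyDenied as exc:
        denied = error_payload(capability_id, f"Policy denied: {exc}")
        return anthropic_tool_result(denied, tool_use_id=tool_use_id, is_error=True)
    return anthropic_tool_result(payload, tool_use_id=tool_use_id)


def denied_call() -> dict[str, Any]:
    raise PolicyDenied("justification required")


block = run_tool("toolu_example", "billing.delete_invoice", denied_call)
```

The agent loop receives `block` as an ordinary value, so it can append it to the conversation and let the model decide how to recover, for example by retrying with a justification.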
diff --git a/pyproject.toml b/pyproject.toml index 33a244d..0159973 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,10 @@ classifiers = [ "Topic :: Security", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["httpx>=0.27"] +dependencies = [ + "httpx>=0.27", + "pydantic>=2", +] [project.urls] Homepage = "https://github.com/dgenio/agent-kernel" diff --git a/src/agent_kernel/__init__.py b/src/agent_kernel/__init__.py index 73fda47..cdef01c 100644 --- a/src/agent_kernel/__init__.py +++ b/src/agent_kernel/__init__.py @@ -25,6 +25,10 @@ from agent_kernel import HandleStore, TraceStore +LLM tool-format adapters:: + + from agent_kernel import OpenAIMiddleware, AnthropicMiddleware + Errors:: from agent_kernel import ( @@ -35,12 +39,14 @@ ) """ +from .adapters import AnthropicMiddleware, OpenAIMiddleware from .drivers.base import Driver, ExecutionContext from .drivers.http import HTTPDriver from .drivers.mcp import MCPDriver from .drivers.memory import InMemoryDriver, make_billing_driver from .enums import SafetyClass, SensitivityTag from .errors import ( + AdapterParseError, AgentKernelError, CapabilityAlreadyRegistered, CapabilityNotFound, @@ -76,6 +82,7 @@ RawResult, ResponseMode, RoutePlan, + ToolHints, ) from .policy import DefaultPolicyEngine, ExplainingPolicyEngine, PolicyEngine from .policy_dsl import DeclarativePolicyEngine, PolicyMatch, PolicyRule @@ -111,10 +118,12 @@ "ResponseMode", "RoutePlan", "ActionTrace", + "ToolHints", # enums "SafetyClass", "SensitivityTag", # errors + "AdapterParseError", "AgentKernelError", "CapabilityAlreadyRegistered", "CapabilityNotFound", @@ -152,4 +161,7 @@ # stores "HandleStore", "TraceStore", + # adapters + "AnthropicMiddleware", + "OpenAIMiddleware", ] diff --git a/src/agent_kernel/adapters/__init__.py b/src/agent_kernel/adapters/__init__.py new file mode 100644 index 0000000..ac20246 --- /dev/null +++ b/src/agent_kernel/adapters/__init__.py @@ -0,0 +1,35 @@ +"""LLM tool-format 
adapters and middleware. + +The adapter layer translates between :class:`~agent_kernel.Capability` objects +and vendor-specific tool shapes (OpenAI Responses / Chat Completions, +Anthropic Messages) without depending on the vendor SDKs at runtime. The +middleware classes also route a vendor's tool-call objects through the full +kernel pipeline (grant → invoke → firewall → trace), returning vendor-shaped +tool-result objects. + +Two middleware classes share a common base (:class:`BaseToolMiddleware`) which +owns hook registration, dispatch, and error-as-result conversion. +""" + +from __future__ import annotations + +from ._base import ( + BaseToolMiddleware, + ToolCallEvent, + ToolCallHook, + ToolResultEvent, + ToolResultHook, +) +from .anthropic import AnthropicMiddleware +from .openai import OpenAIMiddleware, OpenAIToolFormat + +__all__ = [ + "AnthropicMiddleware", + "BaseToolMiddleware", + "OpenAIMiddleware", + "OpenAIToolFormat", + "ToolCallEvent", + "ToolCallHook", + "ToolResultEvent", + "ToolResultHook", +] diff --git a/src/agent_kernel/adapters/_base.py b/src/agent_kernel/adapters/_base.py new file mode 100644 index 0000000..8be85c3 --- /dev/null +++ b/src/agent_kernel/adapters/_base.py @@ -0,0 +1,459 @@ +"""Shared plumbing for LLM tool-format adapters. + +This module is private. Stable adapter API is exported from +:mod:`agent_kernel.adapters` (and re-exported from :mod:`agent_kernel`). 
+ +Both :class:`~agent_kernel.adapters.openai.OpenAIMiddleware` and +:class:`~agent_kernel.adapters.anthropic.AnthropicMiddleware` build on top of +:class:`BaseToolMiddleware`, which owns: + +- hook registration and dispatch (sync or async callables) +- the request → grant → invoke → format flow +- error-as-result conversion for kernel-side failures +- the canonical :class:`Frame` → JSON payload shape +- pydantic-driven schema generation and argument validation +- the dot-notation ↔ ``namespace__function`` round-trip used by OpenAI +""" + +from __future__ import annotations + +import copy +import inspect +import logging +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +from pydantic import ValidationError + +from ..errors import ( + AdapterParseError, + AgentKernelError, + CapabilityNotFound, + DriverError, + PolicyDenied, +) +from ..models import ( + Capability, + CapabilityRequest, + Frame, + Principal, + ResponseMode, +) + +if TYPE_CHECKING: + from ..kernel import Kernel + +logger = logging.getLogger(__name__) + +# Sentinel returned to vendors when a tool call is aborted by a hook. +_ABORT_PREFIX = "Aborted by pre-invocation hook" + +# Used to escape dots in capability IDs for vendors that reject "." in tool names. +_NAMESPACE_SEP = "__" + + +# ── Event objects ───────────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class ToolCallEvent: + """Event delivered to ``intercept_tool_call`` hooks before kernel invocation. + + Hooks may mutate :attr:`args` to override arguments, set + :attr:`justification` to inject a justification (required for WRITE and + DESTRUCTIVE capabilities), or set :attr:`aborted` to skip the call. A + skipped call still produces a tool-result for the LLM — the + :attr:`abort_reason` is included. 
+ """ + + capability_id: str + args: dict[str, Any] + principal_id: str + vendor: Literal["openai", "anthropic"] + call_id: str + justification: str = "" + aborted: bool = False + abort_reason: str = "" + + +@dataclass(slots=True) +class ToolResultEvent: + """Event delivered to ``intercept_tool_result`` hooks after kernel invocation. + + Exactly one of :attr:`frame` or :attr:`error` is non-``None``. Hooks may + replace :attr:`frame` (e.g. to apply caching transformations) or override + :attr:`error` (e.g. to redact internal detail before reaching the LLM). + """ + + capability_id: str + principal_id: str + vendor: Literal["openai", "anthropic"] + call_id: str + frame: Frame | None = None + error: str | None = None + + +ToolCallHook = Callable[[ToolCallEvent], Any] +"""Pre-invocation hook. May return ``None`` or an awaitable.""" + +ToolResultHook = Callable[[ToolResultEvent], Any] +"""Post-invocation hook. May return ``None`` or an awaitable.""" + + +# ── Schema helpers ──────────────────────────────────────────────────────────── + + +def build_input_schema(capability: Capability) -> dict[str, Any]: + """Derive a JSON Schema for the capability's input arguments. + + Resolution order: + + 1. :attr:`Capability.parameters_model` — pydantic ``model_json_schema()``. + 2. :attr:`Capability.parameters_schema` — used verbatim. + 3. Fallback: permissive ``{"type": "object", "additionalProperties": true}``. + + Note: ``allowed_fields`` is an output-redaction control and is intentionally + ignored here. + """ + if capability.parameters_model is not None: + # pydantic 2's mode="validation" mirrors what the model accepts as input. + return capability.parameters_model.model_json_schema(mode="validation") + if capability.parameters_schema is not None: + # Deep-copy so downstream mutation of nested objects (e.g. tweaking + # ``properties[...]["type"]``) does not leak into the registry. 
+ return copy.deepcopy(capability.parameters_schema) + return {"type": "object", "additionalProperties": True} + + +def normalize_for_openai_strict(schema: dict[str, Any]) -> dict[str, Any]: + """Best-effort normalisation of *schema* for OpenAI strict mode. + + OpenAI ``strict: true`` requires every object schema to: + + - list every property in ``required`` + - set ``additionalProperties: false`` + + This walker enforces both, recursively. It does not flatten ``$ref`` / + ``$defs`` (OpenAI's strict mode accepts those). Returns a deep-copied + schema so the caller's input is untouched. + """ + result = _normalize_strict(schema) + # ``_normalize_strict`` returns ``Any`` because it handles three node + # shapes (dict, list, scalar); at this top-level entry the input is a + # dict so the result is too. + assert isinstance(result, dict) + return result + + +def _normalize_strict(node: Any) -> Any: + if isinstance(node, dict): + # Recurse first so nested objects pick up the same treatment. + out: dict[str, Any] = {k: _normalize_strict(v) for k, v in node.items()} + if out.get("type") == "object": + properties = out.get("properties") + if isinstance(properties, dict): + out["required"] = list(properties.keys()) + out["additionalProperties"] = False + return out + if isinstance(node, list): + return [_normalize_strict(v) for v in node] + return node + + +def validate_input(capability: Capability, args: dict[str, Any]) -> dict[str, Any]: + """Validate *args* against the capability's input model, if one is set. + + Returns the validated/coerced dict (pydantic may coerce types such as + string → int per the model's declared types). When the capability has no + :attr:`parameters_model`, returns *args* unchanged — raw schemas are not + validated (use a model if validation matters). + + Raises: + ValidationError: If pydantic rejects the arguments. Callers in this + module catch this and surface the failure as a tool-result error. 
+ """ + if capability.parameters_model is None: + return args + model = capability.parameters_model.model_validate(args) + # ``mode="python"`` keeps nested model instances as dicts the kernel/driver + # already understand; ``by_alias=True`` is unnecessary here because the + # driver consumes the original field names. + return model.model_dump(mode="python") + + +# ── Namespace helpers ───────────────────────────────────────────────────────── + + +def make_namespace_safe_name(capability_id: str) -> str: + """Convert a dotted capability_id into a vendor-safe identifier. + + ``billing.list_invoices`` → ``billing__list_invoices``. The ``__`` separator + is reserved: capability IDs that already contain ``__`` cannot be + round-tripped unambiguously (``"a__b"`` and ``"a.b"`` would both map to + ``"a__b"``), so they are rejected at adapter-emit time rather than + silently producing colliding OpenAI tool names. + + Raises: + AdapterParseError: If *capability_id* contains the reserved ``__`` + separator. Single underscores are fine (``"list_invoices_v2"`` + round-trips cleanly). + """ + if _NAMESPACE_SEP in capability_id: + raise AdapterParseError( + f"Capability ID '{capability_id}' contains the reserved namespace " + f"separator '{_NAMESPACE_SEP}'. The OpenAI adapter would map it to " + f"a tool name that collides with dotted capability IDs (e.g. " + f"'a__b' and 'a.b' both produce 'a__b'). Rename the capability or " + f"strip the double underscore." + ) + return capability_id.replace(".", _NAMESPACE_SEP) + + +def restore_namespace(safe_name: str) -> str: + """Inverse of :func:`make_namespace_safe_name`. + + ``billing__list_invoices`` → ``billing.list_invoices``. + """ + return safe_name.replace(_NAMESPACE_SEP, ".") + + +# ── Payload helpers ─────────────────────────────────────────────────────────── + + +def frame_to_payload(frame: Frame) -> dict[str, Any]: + """Canonical JSON-serialisable shape for a kernel :class:`Frame`. 
+ + Used by both OpenAI and Anthropic adapters as the tool-result body. The + shape is deterministic so LLM prompt caches remain stable. + """ + handle: dict[str, Any] | None = None + if frame.handle is not None: + handle = { + "handle_id": frame.handle.handle_id, + "total_rows": frame.handle.total_rows, + } + return { + "action_id": frame.action_id, + "capability_id": frame.capability_id, + "response_mode": frame.response_mode, + "facts": list(frame.facts), + "table_preview": list(frame.table_preview), + "warnings": list(frame.warnings), + "handle": handle, + } + + +def error_to_payload(*, capability_id: str, error: str) -> dict[str, Any]: + """Canonical JSON-serialisable shape for a tool-result error.""" + return { + "error": True, + "capability_id": capability_id, + "message": error, + } + + +# ── Middleware base ─────────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class PreparedCall: + """A parsed tool call ready for dispatch through the kernel pipeline. + + Adapters parse vendor-specific shapes into this neutral form before + handing them to :meth:`BaseToolMiddleware._dispatch_one`. + """ + + call_id: str + capability_id: str + args: dict[str, Any] + + +class BaseToolMiddleware: + """Shared base class for vendor-specific tool middleware. + + Subclasses implement only the vendor-specific shape adapters + (``capabilities_to_tools``, ``handle_*``); the request/grant/invoke flow, + hook dispatch, and error handling live here. 
+ """ + + vendor: ClassVar[Literal["openai", "anthropic"]] + + def __init__( + self, + kernel: Kernel, + principal: Principal, + *, + response_mode: ResponseMode = "summary", + ) -> None: + self._kernel = kernel + self._principal = principal + self._response_mode: ResponseMode = response_mode + self._pre_hooks: list[ToolCallHook] = [] + self._post_hooks: list[ToolResultHook] = [] + + # ── Hooks ────────────────────────────────────────────────────────────── + + def intercept_tool_call(self, callback: ToolCallHook) -> None: + """Register a pre-invocation hook. + + Hooks are dispatched in registration order. Sync and async callables + are both supported; an awaitable return value is awaited. + """ + self._pre_hooks.append(callback) + + def intercept_tool_result(self, callback: ToolResultHook) -> None: + """Register a post-invocation hook. + + Hooks are dispatched in registration order. Sync and async callables + are both supported; an awaitable return value is awaited. + """ + self._post_hooks.append(callback) + + # ── Internal — capability lookup ─────────────────────────────────────── + + def _list_capabilities(self) -> list[Capability]: + return self._kernel.list_capabilities() + + def _get_capability(self, capability_id: str) -> Capability | None: + for cap in self._kernel.list_capabilities(): + if cap.capability_id == capability_id: + return cap + return None + + # ── Internal — dispatch ──────────────────────────────────────────────── + + async def _dispatch_one( + self, + prepared: PreparedCall, + *, + batch_justification: str, + ) -> ToolResultEvent: + """Run a single prepared call through hooks → kernel → hooks.""" + principal_id = self._principal.principal_id + per_call_justification = "" + args = dict(prepared.args) + if "_justification" in args: + value = args.pop("_justification") + per_call_justification = str(value) if value is not None else "" + + event = ToolCallEvent( + capability_id=prepared.capability_id, + args=args, + principal_id=principal_id, + 
vendor=self.vendor, + call_id=prepared.call_id, + justification=per_call_justification or batch_justification, + ) + + result_event = ToolResultEvent( + capability_id=prepared.capability_id, + principal_id=principal_id, + vendor=self.vendor, + call_id=prepared.call_id, + ) + + try: + await self._fire_pre_hooks(event) + except Exception as exc: # noqa: BLE001 — hook errors become tool errors + result_event.error = f"Pre-invocation hook raised: {exc}" + await self._fire_post_hooks(result_event) + return result_event + + if event.aborted: + reason = event.abort_reason or "no reason provided" + result_event.error = f"{_ABORT_PREFIX}: {reason}" + await self._fire_post_hooks(result_event) + return result_event + + result_event.frame, result_event.error = await self._invoke_capability(event) + await self._fire_post_hooks(result_event) + return result_event + + async def _invoke_capability(self, event: ToolCallEvent) -> tuple[Frame | None, str | None]: + """Run grant + invoke, mapping kernel exceptions to tool-result errors.""" + capability = self._get_capability(event.capability_id) + if capability is None: + return None, (f"Capability '{event.capability_id}' is not registered in this kernel.") + + # Validate arguments against the capability's schema (if any). + try: + validated_args = validate_input(capability, event.args) + except ValidationError as exc: + return None, f"Argument validation failed: {exc}" + + request = CapabilityRequest( + capability_id=event.capability_id, + goal=f"adapter:{self.vendor}", + constraints={}, + ) + + try: + grant = self._kernel.grant_capability( + request, self._principal, justification=event.justification + ) + except PolicyDenied as exc: + return None, f"Policy denied: {exc}" + except CapabilityNotFound as exc: + # Race between list_capabilities and grant; surface cleanly. 
+ return None, f"Capability not found: {exc}" + + try: + frame = await self._kernel.invoke( + grant.token, + principal=self._principal, + args=validated_args, + response_mode=self._response_mode, + ) + except DriverError as exc: + return None, f"Driver error: {exc}" + except AgentKernelError as exc: + return None, f"Kernel error: {exc}" + return frame, None + + # ── Hook dispatch ────────────────────────────────────────────────────── + + async def _fire_pre_hooks(self, event: ToolCallEvent) -> None: + for hook in self._pre_hooks: + await self._await_if_needed(hook(event)) + + async def _fire_post_hooks(self, event: ToolResultEvent) -> None: + for hook in self._post_hooks: + # Hook exceptions during post-processing are logged but never + # crash the surrounding tool-call batch. + try: + await self._await_if_needed(hook(event)) + except Exception as exc: # noqa: BLE001 + logger.warning( + "post_hook_failed", + extra={ + "capability_id": event.capability_id, + "call_id": event.call_id, + "vendor": self.vendor, + "error": str(exc), + }, + ) + + @staticmethod + async def _await_if_needed(value: Any) -> None: + if inspect.isawaitable(value): + await value + + +# Public re-exports so ``from agent_kernel.adapters import X`` resolves +# cleanly even for internals subclasses lean on. +__all__ = [ + "BaseToolMiddleware", + "PreparedCall", + "ToolCallEvent", + "ToolCallHook", + "ToolResultEvent", + "ToolResultHook", + "build_input_schema", + "error_to_payload", + "frame_to_payload", + "make_namespace_safe_name", + "normalize_for_openai_strict", + "restore_namespace", + "validate_input", +] diff --git a/src/agent_kernel/adapters/anthropic.py b/src/agent_kernel/adapters/anthropic.py new file mode 100644 index 0000000..ae43074 --- /dev/null +++ b/src/agent_kernel/adapters/anthropic.py @@ -0,0 +1,273 @@ +"""Anthropic tool-format adapter and middleware. 
+ +Emits Anthropic Messages API tool definitions with optional ``cache_control`` +support, and converts ``tool_use`` content blocks through the kernel pipeline +into ``tool_result`` content blocks. + +Anthropic preserves dotted capability IDs as-is (their tool name field accepts +``[a-zA-Z0-9_.-]``), so no namespace transformation is required. + +No runtime dependency on the ``anthropic`` package — every public function +takes and returns plain dicts. Pydantic (a kernel runtime dep) handles schema +generation and argument validation. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +from ..errors import AdapterParseError +from ..models import Capability, CapabilityRequest, ResponseMode +from ._base import ( + BaseToolMiddleware, + PreparedCall, + build_input_schema, + error_to_payload, + frame_to_payload, +) + +if TYPE_CHECKING: + from ..kernel import Kernel + from ..models import Principal + from ._base import ToolResultEvent + +logger = logging.getLogger(__name__) + + +# ── Schema conversion ───────────────────────────────────────────────────────── + + +def capabilities_to_tools( + capabilities: list[Capability], + *, + default_cache_control: dict[str, Any] | None = None, +) -> list[dict[str, Any]]: + """Convert :class:`Capability` objects to Anthropic tool definitions. + + Args: + capabilities: Capabilities to expose as tools. + default_cache_control: Default ``cache_control`` block applied to every + tool that does not specify its own via :attr:`Capability.tool_hints`. + Per-capability ``ToolHints.cache_control`` takes precedence. + + Returns: + List of dicts shaped like ``{"name", "description", "input_schema", + "cache_control"?}`` ready to pass to the Anthropic SDK. 
+ """ + return [ + _capability_to_tool(cap, default_cache_control=default_cache_control) + for cap in capabilities + ] + + +def _capability_to_tool( + capability: Capability, + *, + default_cache_control: dict[str, Any] | None, +) -> dict[str, Any]: + tool: dict[str, Any] = { + "name": capability.capability_id, + "description": _describe(capability), + "input_schema": build_input_schema(capability), + } + cache_control = _resolve_cache_control(capability, default_cache_control) + if cache_control is not None: + tool["cache_control"] = cache_control + return tool + + +def _resolve_cache_control( + capability: Capability, + default_cache_control: dict[str, Any] | None, +) -> dict[str, Any] | None: + """Per-capability ``cache_control`` from ``tool_hints`` wins over the default.""" + if capability.tool_hints is not None and capability.tool_hints.cache_control is not None: + return dict(capability.tool_hints.cache_control) + if default_cache_control is not None: + return dict(default_cache_control) + return None + + +def _describe(capability: Capability) -> str: + """Build a description that surfaces safety/sensitivity to the LLM.""" + parts = [capability.description, f"[safety={capability.safety_class.value}]"] + if capability.sensitivity.value != "NONE": + parts.append(f"[sensitivity={capability.sensitivity.value}]") + return " ".join(parts) + + +# ── tool_use → CapabilityRequest ────────────────────────────────────────────── + + +def tool_use_to_request(tool_use_block: dict[str, Any]) -> CapabilityRequest: + """Convert an Anthropic ``tool_use`` content block to a :class:`CapabilityRequest`. + + Expected input shape:: + + {"type": "tool_use", "id": "toolu_xxx", "name": "billing.list_invoices", + "input": {"customer_id": "..."}} + + Anthropic delivers ``input`` as an object (not a JSON string), so no parsing + is needed beyond a defensive copy. + + Raises: + AdapterParseError: If the block is missing ``name`` or ``input`` has + the wrong type. 
+ """ + name = tool_use_block.get("name") + if not isinstance(name, str) or not name: + raise AdapterParseError( + "Anthropic tool_use block is missing a 'name' field or it is not a string." + ) + raw_input = tool_use_block.get("input", {}) + if raw_input is None: + raw_input = {} + if not isinstance(raw_input, dict): + raise AdapterParseError( + f"Anthropic tool_use 'input' must be an object (got {type(raw_input).__name__})." + ) + return CapabilityRequest( + capability_id=name, + goal="adapter:anthropic", + constraints={}, + ) + + +# ── Frame / error → Anthropic tool_result ───────────────────────────────────── + + +def format_result( + payload: dict[str, Any], + *, + tool_use_id: str, + is_error: bool = False, +) -> dict[str, Any]: + """Wrap a payload dict in an Anthropic ``tool_result`` content block. + + The payload is serialised to a single ``{"type": "text", "text": }`` + content block so the LLM can reason over a stable, parseable shape and so + downstream tool chains preserve content-block structure. + """ + body = json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str) + block: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": [{"type": "text", "text": body}], + } + if is_error: + block["is_error"] = True + return block + + +# ── Middleware ──────────────────────────────────────────────────────────────── + + +class AnthropicMiddleware(BaseToolMiddleware): + """Drop-in middleware for Anthropic Messages tool use. + + Example:: + + kernel = Kernel(registry=registry, ...) + mw = AnthropicMiddleware(kernel, principal) + tools = mw.get_tools() + # ... pass tools to the Anthropic client ... 
+ tool_results = await mw.handle_tool_uses(message.content) + """ + + vendor: ClassVar[Literal["openai", "anthropic"]] = "anthropic" + + def __init__( + self, + kernel: Kernel, + principal: Principal, + *, + response_mode: ResponseMode = "summary", + default_cache_control: dict[str, Any] | None = None, + ) -> None: + super().__init__(kernel, principal, response_mode=response_mode) + self._default_cache_control = ( + dict(default_cache_control) if default_cache_control is not None else None + ) + + # ── Public API ───────────────────────────────────────────────────────── + + def get_tools(self) -> list[dict[str, Any]]: + """Return every registered capability as an Anthropic tool definition.""" + return capabilities_to_tools( + self._list_capabilities(), + default_cache_control=self._default_cache_control, + ) + + async def handle_tool_uses( + self, + content_blocks: list[dict[str, Any]], + *, + justification: str = "", + ) -> list[dict[str, Any]]: + """Process every ``tool_use`` block in *content_blocks* through the kernel. + + Args: + content_blocks: An assistant message's ``content`` list. Non- + ``tool_use`` blocks (text, etc.) are passed over so the caller + can hand in raw ``message.content`` directly. + justification: Justification applied to every call in the batch. + Individual calls may override by including + ``"_justification": "..."`` in their ``input``. + + Returns: + One ``tool_result`` content block per processed ``tool_use``, in + input order. Errors are returned as ``tool_result`` blocks with + ``is_error: true`` rather than raised. + """ + results: list[dict[str, Any]] = [] + for block in content_blocks: + if block.get("type") != "tool_use": + continue + tool_use_id = str(block.get("id", "")) + try: + request = tool_use_to_request(block) + except AdapterParseError as exc: + # The capability_id is "(unresolved)" because the parse failed + # before we could extract one — angle-bracket sentinels like + # "" read as HTML/placeholder text to some LLMs. 
+ results.append( + format_result( + error_to_payload(capability_id="(unresolved)", error=str(exc)), + tool_use_id=tool_use_id, + is_error=True, + ) + ) + continue + + raw_input = block.get("input") or {} + # Defensive copy so hook mutation does not leak into the caller's data. + args = dict(raw_input) if isinstance(raw_input, dict) else {} + + prepared = PreparedCall( + call_id=tool_use_id, + capability_id=request.capability_id, + args=args, + ) + event = await self._dispatch_one(prepared, batch_justification=justification) + results.append(self._format_event(event)) + return results + + # ── Internal — vendor-shape result envelope ──────────────────────────── + + def _format_event(self, event: ToolResultEvent) -> dict[str, Any]: + if event.error is not None: + payload = error_to_payload(capability_id=event.capability_id, error=event.error) + return format_result(payload, tool_use_id=event.call_id, is_error=True) + assert event.frame is not None + payload = frame_to_payload(event.frame) + return format_result(payload, tool_use_id=event.call_id, is_error=False) + + +__all__ = [ + "AnthropicMiddleware", + "capabilities_to_tools", + "format_result", + "tool_use_to_request", +] diff --git a/src/agent_kernel/adapters/openai.py b/src/agent_kernel/adapters/openai.py new file mode 100644 index 0000000..8608b0e --- /dev/null +++ b/src/agent_kernel/adapters/openai.py @@ -0,0 +1,358 @@ +"""OpenAI tool-format adapter and middleware. + +Supports both OpenAI tool-shape conventions: + +- **Responses API** (default) — flat ``{"type": "function", "name", ...}`` tool + definitions, ``function_call`` request items, ``function_call_output`` result + items keyed by ``call_id``. +- **Chat Completions API** — nested ``{"type": "function", "function": {...}}`` + tool definitions, ``tool_calls[].function.arguments`` requests, ``{"role": + "tool", "tool_call_id", "content"}`` result messages. + +Tool-call shape is auto-detected on input regardless of the configured output +format. 
+ +No runtime dependency on the ``openai`` package — every public function takes +and returns plain dicts. Pydantic (a kernel runtime dep) handles schema +generation and argument validation. +""" + +from __future__ import annotations + +import json +import logging +import warnings +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +from ..errors import AdapterParseError +from ..models import Capability, CapabilityRequest, ResponseMode +from ._base import ( + BaseToolMiddleware, + PreparedCall, + build_input_schema, + error_to_payload, + frame_to_payload, + make_namespace_safe_name, + normalize_for_openai_strict, + restore_namespace, +) + +if TYPE_CHECKING: + from ..kernel import Kernel + from ..models import Principal + from ._base import ToolResultEvent + +logger = logging.getLogger(__name__) + +OpenAIToolFormat = Literal["responses", "chat_completions"] +"""Supported OpenAI tool/output shapes. + +``responses`` matches the Responses API; ``chat_completions`` matches the +Chat Completions API. See module docstring for the per-format differences. +""" + +_DEFAULT_FORMAT: OpenAIToolFormat = "responses" + + +# ── Schema conversion ───────────────────────────────────────────────────────── + + +def capabilities_to_tools( + capabilities: list[Capability], + *, + format: OpenAIToolFormat = _DEFAULT_FORMAT, +) -> list[dict[str, Any]]: + """Convert :class:`Capability` objects to OpenAI tool definitions. + + Args: + capabilities: Capabilities to expose as tools. + format: ``"responses"`` (default) emits flat Responses-API tool + definitions; ``"chat_completions"`` emits nested Chat Completions + tool definitions. + + Returns: + Vendor-shaped tool definition dicts ready to pass to the OpenAI SDK. 
+ """ + return [_capability_to_tool(cap, format=format) for cap in capabilities] + + +def _capability_to_tool(capability: Capability, *, format: OpenAIToolFormat) -> dict[str, Any]: + description = _describe(capability) + parameters = build_input_schema(capability) + name = make_namespace_safe_name(capability.capability_id) + + strict = bool(capability.tool_hints and capability.tool_hints.strict) + if strict: + try: + parameters = normalize_for_openai_strict(parameters) + except Exception as exc: # noqa: BLE001 — fall back to non-strict + warnings.warn( + f"OpenAI strict-mode normalisation failed for '{capability.capability_id}'" + f": {exc}. Emitting tool definition without strict.", + stacklevel=2, + ) + strict = False + + if format == "chat_completions": + function: dict[str, Any] = { + "name": name, + "description": description, + "parameters": parameters, + } + if strict: + function["strict"] = True + return {"type": "function", "function": function} + + # Responses API: flat shape. + tool: dict[str, Any] = { + "type": "function", + "name": name, + "description": description, + "parameters": parameters, + } + if strict: + tool["strict"] = True + return tool + + +def _describe(capability: Capability) -> str: + """Build the user-facing description with safety/sensitivity context. + + Surfacing ``safety_class`` lets the LLM make better tool-choice decisions + (e.g. avoid DESTRUCTIVE tools when a READ would suffice). + """ + parts = [capability.description] + parts.append(f"[safety={capability.safety_class.value}]") + if capability.sensitivity.value != "NONE": + parts.append(f"[sensitivity={capability.sensitivity.value}]") + return " ".join(parts) + + +# ── Tool call → CapabilityRequest ───────────────────────────────────────────── + + +def tool_call_to_request(tool_call: dict[str, Any]) -> CapabilityRequest: + """Convert an OpenAI tool call dict to a :class:`CapabilityRequest`. 
+ + Auto-detects the input shape: + + - **Chat Completions:** ``{"id": "call_x", "type": "function", + "function": {"name": "...", "arguments": ""}}`` + - **Responses:** ``{"type": "function_call", "call_id": "fc_x", + "name": "...", "arguments": ""}`` + + The ``arguments`` field is always a JSON-encoded string per the OpenAI + spec; this function parses it. + + Raises: + AdapterParseError: If the dict shape isn't recognisable as either + format, or if ``arguments`` is not valid JSON. + """ + name, _ = _extract_name_and_call_id(tool_call) + # Force-parse arguments so callers get the JSON-decode error here, not + # later when the middleware tries to invoke. The parsed value is dropped: + # ``CapabilityRequest`` carries the capability_id + goal only; args are + # threaded through ``handle_tool_calls`` separately. + _parse_arguments(tool_call.get("arguments"), tool_call) + capability_id = restore_namespace(name) + return CapabilityRequest( + capability_id=capability_id, + goal="adapter:openai", + constraints={}, + ) + + +def _extract_name_and_call_id(tool_call: dict[str, Any]) -> tuple[str, str]: + """Return ``(function_name, call_id)`` regardless of input format.""" + fn = tool_call.get("function") + if isinstance(fn, dict): + # Chat Completions: nested function.{name, arguments}, id at top level. + name = fn.get("name") + call_id = tool_call.get("id", "") + else: + # Responses API: flat name/call_id/arguments at top level. + name = tool_call.get("name") + call_id = tool_call.get("call_id", "") or tool_call.get("id", "") + if not isinstance(name, str) or not name: + raise AdapterParseError( + "OpenAI tool_call is missing a function name. Expected either " + "'function.name' (Chat Completions) or 'name' (Responses API)." 
+ ) + return name, str(call_id) + + +def _parse_arguments(raw: Any, tool_call: dict[str, Any]) -> dict[str, Any]: + """Parse the JSON-encoded ``arguments`` field, with format fallback.""" + if raw is None: + fn = tool_call.get("function") + if isinstance(fn, dict): + raw = fn.get("arguments") + if raw is None or raw == "": + return {} + if isinstance(raw, dict): + # Some OpenAI clients pre-parse arguments; accept that shape too. + return dict(raw) + if not isinstance(raw, str): + raise AdapterParseError( + f"OpenAI tool_call 'arguments' must be a JSON string or dict, got {type(raw).__name__}." + ) + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise AdapterParseError(f"OpenAI tool_call 'arguments' is not valid JSON: {exc}") from exc + if not isinstance(parsed, dict): + raise AdapterParseError( + "OpenAI tool_call 'arguments' must decode to a JSON object (got " + f"{type(parsed).__name__})." + ) + return parsed + + +# ── Frame / error → OpenAI result ───────────────────────────────────────────── + + +def format_result( + payload: dict[str, Any], + *, + call_id: str, + format: OpenAIToolFormat = _DEFAULT_FORMAT, +) -> dict[str, Any]: + """Wrap a payload dict in an OpenAI tool-result envelope. + + *payload* should already be the canonical kernel-result body produced by + :func:`agent_kernel.adapters._base.frame_to_payload` or + :func:`agent_kernel.adapters._base.error_to_payload`. + """ + body = json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str) + if format == "chat_completions": + return {"role": "tool", "tool_call_id": call_id, "content": body} + return {"type": "function_call_output", "call_id": call_id, "output": body} + + +# ── Middleware ──────────────────────────────────────────────────────────────── + + +class OpenAIMiddleware(BaseToolMiddleware): + """Drop-in middleware for OpenAI Responses / Chat Completions tool use. + + Example:: + + kernel = Kernel(registry=registry, ...) 
+ mw = OpenAIMiddleware(kernel, principal) + tools = mw.get_tools() + # ... pass tools to the OpenAI client ... + outputs = await mw.handle_tool_calls(response.output) + """ + + vendor: ClassVar[Literal["openai", "anthropic"]] = "openai" + + def __init__( + self, + kernel: Kernel, + principal: Principal, + *, + response_mode: ResponseMode = "summary", + format: OpenAIToolFormat = _DEFAULT_FORMAT, + ) -> None: + super().__init__(kernel, principal, response_mode=response_mode) + self._format: OpenAIToolFormat = format + + # ── Public API ───────────────────────────────────────────────────────── + + def get_tools(self) -> list[dict[str, Any]]: + """Return every registered capability as an OpenAI tool definition.""" + return capabilities_to_tools(self._list_capabilities(), format=self._format) + + async def handle_tool_calls( + self, + tool_calls: list[dict[str, Any]], + *, + justification: str = "", + ) -> list[dict[str, Any]]: + """Process a batch of OpenAI tool calls through the kernel pipeline. + + Args: + tool_calls: Either ``response.output`` items from the Responses API + or ``message.tool_calls`` items from the Chat Completions API. + Non-function items (e.g. ``message`` / text items in + Responses-API output) are skipped — the caller stitches the + returned envelopes back into the conversation alongside the + original items. Input shape is auto-detected per call. + justification: Justification applied to every call in the batch. + Individual calls may override by including + ``"_justification": "..."`` in their arguments. + + Returns: + One vendor-shaped result envelope per *processed* tool call, in + input order. Non-tool-call items in the input are skipped (no + envelope is emitted for them) so the caller can interleave + results with the original conversation items. 
+ """ + outputs: list[dict[str, Any]] = [] + for tool_call in tool_calls: + if not _looks_like_tool_call(tool_call): + continue + try: + name, call_id = _extract_name_and_call_id(tool_call) + args = _parse_arguments(tool_call.get("arguments"), tool_call) + except AdapterParseError as exc: + # Surface parse failures as a tool result so the LLM sees the + # error rather than the agent loop crashing. The capability_id + # is "(unresolved)" because the parse failed before we could + # extract one — angle-bracket sentinels like "" read + # as HTML/placeholder text to some LLMs. + outputs.append( + format_result( + error_to_payload(capability_id="(unresolved)", error=str(exc)), + call_id=str(tool_call.get("id") or tool_call.get("call_id") or ""), + format=self._format, + ) + ) + continue + + prepared = PreparedCall( + call_id=call_id, + capability_id=restore_namespace(name), + args=args, + ) + event = await self._dispatch_one(prepared, batch_justification=justification) + outputs.append(self._format_event(event)) + return outputs + + # ── Internal — vendor-shape result envelope ──────────────────────────── + + def _format_event(self, event: ToolResultEvent) -> dict[str, Any]: + if event.error is not None: + payload = error_to_payload(capability_id=event.capability_id, error=event.error) + else: + # frame is guaranteed non-None when error is None. + assert event.frame is not None + payload = frame_to_payload(event.frame) + return format_result(payload, call_id=event.call_id, format=self._format) + + +def _looks_like_tool_call(item: dict[str, Any]) -> bool: + """Detect whether an item is an OpenAI function/tool call. + + Filters Responses-API output items that aren't function calls (e.g. + text/message items) so the caller can pass ``response.output`` directly. + """ + item_type = item.get("type") + if item_type == "function_call": + return True + if item_type == "function": + # Chat Completions style: tools list element. 
Function calls in + # response messages also use type == "function". + return True + # Chat Completions sometimes omits "type" on tool_calls list entries + # (depending on SDK version); fall back to detecting the nested function. + return isinstance(item.get("function"), dict) + + +__all__ = [ + "OpenAIMiddleware", + "OpenAIToolFormat", + "capabilities_to_tools", + "format_result", + "tool_call_to_request", +] diff --git a/src/agent_kernel/errors.py b/src/agent_kernel/errors.py index 0b4d15f..3b5e342 100644 --- a/src/agent_kernel/errors.py +++ b/src/agent_kernel/errors.py @@ -49,6 +49,25 @@ class FirewallError(AgentKernelError): """Raised when the context firewall cannot transform a raw result.""" +# ── Adapter errors ──────────────────────────────────────────────────────────── + + +class AdapterParseError(AgentKernelError): + """Raised when an LLM tool-format adapter cannot parse vendor input. + + Covers two adapter-side failure modes: + + - Malformed tool-call shapes: missing fields, non-JSON ``arguments``, wrong + types (e.g. ``arguments`` is an int). + - Capability-ID validation: e.g. capability IDs that contain the OpenAI + namespace separator (``__``) cannot be round-tripped unambiguously and + are rejected at tool-emit time. + + Callers can catch this to distinguish adapter parse / validation failures + from kernel-side errors (:class:`PolicyDenied`, :class:`DriverError`). 
+ """ + + # ── Registry / lookup errors ────────────────────────────────────────────────── diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index 2611471..220853d 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -14,6 +14,7 @@ from .handles import HandleStore from .models import ( ActionTrace, + Capability, CapabilityGrant, CapabilityRequest, DenialExplanation, @@ -83,6 +84,15 @@ def register_driver(self, driver: Driver) -> None: # ── Public API ───────────────────────────────────────────────────────────── + def list_capabilities(self) -> list[Capability]: + """Return every capability registered with the kernel. + + Convenience accessor used by LLM adapters that need to enumerate the + full registry (e.g. ``OpenAIMiddleware.get_tools()``) without reaching + into private state. Capabilities are returned in registration order. + """ + return self._registry.list_all() + def request_capabilities( self, goal: str, diff --git a/src/agent_kernel/models.py b/src/agent_kernel/models.py index 5b7c262..4d58c2e 100644 --- a/src/agent_kernel/models.py +++ b/src/agent_kernel/models.py @@ -13,6 +13,8 @@ from .enums import SafetyClass, SensitivityTag if TYPE_CHECKING: + from pydantic import BaseModel + from .tokens import CapabilityToken ResponseMode = Literal["summary", "table", "handle_only", "raw"] @@ -32,6 +34,32 @@ class ImplementationRef: """Operation name understood by the driver (e.g. ``"list_invoices"``).""" +@dataclass(slots=True) +class ToolHints: + """Vendor-specific tool-definition hints for LLM adapters. + + Consumed by ``agent_kernel.adapters`` when emitting tool schemas. + Engines that don't recognise a hint silently ignore it; setting a hint + never changes how the kernel itself behaves. + """ + + cache_control: dict[str, Any] | None = None + """Anthropic prompt-cache control block (e.g. ``{"type": "ephemeral"}``). + + Forwarded verbatim to the Anthropic tool definition. Ignored by other adapters. 
+ """ + + strict: bool = False + """When ``True``, OpenAI tool definitions are emitted with ``strict: true``. + + The capability's ``parameters_model`` (or ``parameters_schema``) must produce a + JSON Schema that satisfies OpenAI's strict-mode rules (every property required, + ``additionalProperties: false`` on all objects). The adapter normalises objects + where possible and falls back to non-strict with a warning if normalisation fails. + Ignored by other adapters. + """ + + @dataclass(slots=True) class Capability: """A task-shaped unit of work that can be authorized and executed.""" @@ -52,7 +80,12 @@ class Capability: """Optional sensitivity tag.""" allowed_fields: list[str] = field(default_factory=list) - """If non-empty, only these fields are returned unless the caller has ``pii_reader``.""" + """If non-empty, only these fields are returned unless the caller has ``pii_reader``. + + Note: this is an **output redaction** control consumed by the firewall — it does + not describe the capability's input parameters. For input schemas use + :attr:`parameters_model` or :attr:`parameters_schema`. + """ tags: list[str] = field(default_factory=list) """Arbitrary keyword tags used for capability matching.""" @@ -60,6 +93,30 @@ class Capability: impl: ImplementationRef | None = None """Optional pointer to the implementation.""" + parameters_model: type[BaseModel] | None = None + """Optional pydantic model describing the capability's input parameters. + + When present, LLM adapters generate the tool's JSON Schema from + ``parameters_model.model_json_schema()`` and validate incoming tool-call + arguments against the model before invocation. Takes precedence over + :attr:`parameters_schema`. + """ + + parameters_schema: dict[str, Any] | None = None + """Optional raw JSON Schema for the capability's input parameters. + + Used by LLM adapters as a fallback schema source when no + :attr:`parameters_model` is supplied. 
Forwarded to the vendor tool definition + verbatim; the adapter does not validate incoming arguments against it (use + :attr:`parameters_model` for validation). + """ + + tool_hints: ToolHints | None = None + """Vendor-specific hints consumed by LLM adapters (e.g. Anthropic + ``cache_control``, OpenAI ``strict`` mode). Has no effect on kernel routing + or policy. See :class:`ToolHints`. + """ + # ── Request / Grant ─────────────────────────────────────────────────────────── diff --git a/tests/test_adapters.py b/tests/test_adapters.py new file mode 100644 index 0000000..1bf90f6 --- /dev/null +++ b/tests/test_adapters.py @@ -0,0 +1,1130 @@ +"""Tests for LLM tool-format adapters (OpenAI + Anthropic). + +Adapters are pure dict transforms with a thin async middleware on top. These +tests exercise both pieces without depending on the ``openai`` or +``anthropic`` SDKs. +""" + +from __future__ import annotations + +import json +from typing import Any + +import pytest +from pydantic import BaseModel, Field + +from agent_kernel import ( + AdapterParseError, + AnthropicMiddleware, + Capability, + CapabilityRegistry, + Kernel, + OpenAIMiddleware, + Principal, + SafetyClass, + SensitivityTag, + ToolHints, +) +from agent_kernel.adapters import ( + ToolCallEvent, + ToolResultEvent, +) +from agent_kernel.adapters import ( + anthropic as anthropic_mod, +) +from agent_kernel.adapters import ( + openai as openai_mod, +) +from agent_kernel.adapters._base import ( + build_input_schema, + error_to_payload, + frame_to_payload, + make_namespace_safe_name, + normalize_for_openai_strict, + restore_namespace, + validate_input, +) +from agent_kernel.models import CapabilityRequest + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +class _InvoiceArgs(BaseModel): + """Schema model used to validate billing.list_invoices arguments.""" + + operation: str = Field(default="list_invoices") + customer_id: str + limit: int = Field(default=10, ge=1, le=100) + + +def 
_cap_with_model(cap_id: str = "billing.list_invoices") -> Capability: + return Capability( + capability_id=cap_id, + name="List Invoices", + description="List invoices for a customer", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + parameters_model=_InvoiceArgs, + ) + + +# ── Capability model extensions ─────────────────────────────────────────────── + + +def test_capability_defaults_preserve_backward_compat() -> None: + """New fields all default to None — existing constructors keep working.""" + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + ) + assert cap.parameters_model is None + assert cap.parameters_schema is None + assert cap.tool_hints is None + + +def test_tool_hints_dataclass_defaults() -> None: + hints = ToolHints() + assert hints.cache_control is None + assert hints.strict is False + + +# ── Schema helpers (_base) ──────────────────────────────────────────────────── + + +def test_build_input_schema_uses_parameters_model() -> None: + cap = _cap_with_model() + schema = build_input_schema(cap) + assert schema["type"] == "object" + assert set(schema["properties"].keys()) == {"operation", "customer_id", "limit"} + # Pydantic emits required for fields without defaults. + assert schema["required"] == ["customer_id"] + + +def test_build_input_schema_uses_parameters_schema_fallback() -> None: + raw = {"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]} + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + parameters_schema=raw, + ) + schema = build_input_schema(cap) + assert schema == raw + # build_input_schema copies — mutating the result must not bleed into the capability. 
+ schema["properties"]["x"]["type"] = "integer" + assert cap.parameters_schema is not None + assert cap.parameters_schema["properties"]["x"]["type"] == "string" + + +def test_build_input_schema_permissive_default() -> None: + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + ) + schema = build_input_schema(cap) + assert schema == {"type": "object", "additionalProperties": True} + + +def test_normalize_for_openai_strict_required_and_no_additional() -> None: + schema = { + "type": "object", + "properties": { + "a": {"type": "string"}, + "b": { + "type": "object", + "properties": {"c": {"type": "integer"}}, + }, + }, + } + out = normalize_for_openai_strict(schema) + assert out["required"] == ["a", "b"] + assert out["additionalProperties"] is False + # Recursive: nested object also gets the treatment. + assert out["properties"]["b"]["required"] == ["c"] + assert out["properties"]["b"]["additionalProperties"] is False + # Original is untouched. 
+ assert "required" not in schema + assert "additionalProperties" not in schema + + +def test_validate_input_passes_through_without_model() -> None: + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + ) + args = {"anything": "goes"} + assert validate_input(cap, args) == args + + +def test_validate_input_with_model_coerces_and_returns_dict() -> None: + cap = _cap_with_model() + out = validate_input(cap, {"customer_id": "c-1", "limit": "5"}) + assert out["customer_id"] == "c-1" + assert out["limit"] == 5 # pydantic coerced str → int + assert out["operation"] == "list_invoices" + + +def test_validate_input_with_model_raises_on_bad_args() -> None: + from pydantic import ValidationError + + cap = _cap_with_model() + with pytest.raises(ValidationError): + validate_input(cap, {"limit": 5}) # missing customer_id + + +# ── Namespace helpers ───────────────────────────────────────────────────────── + + +def test_namespace_round_trip() -> None: + original = "billing.list_invoices" + safe = make_namespace_safe_name(original) + assert safe == "billing__list_invoices" + assert restore_namespace(safe) == original + + +def test_namespace_preserves_single_underscores() -> None: + """Capabilities with underscores in segments must round-trip unambiguously.""" + original = "billing.list_invoices_v2" + assert restore_namespace(make_namespace_safe_name(original)) == original + + +def test_namespace_rejects_capability_id_with_reserved_separator() -> None: + """``__`` in a capability_id collides with the OpenAI namespace separator. + + ``"a__b"`` and ``"a.b"`` would both map to OpenAI tool name ``"a__b"`` — + a silent collision. ``make_namespace_safe_name`` rejects the input rather + than producing a colliding tool name. 
+ """ + with pytest.raises(AdapterParseError, match="reserved namespace separator"): + make_namespace_safe_name("a__b") + + +def test_namespace_collision_surfaces_via_capabilities_to_tools() -> None: + """An invalid capability_id surfaces at adapter-emit time. + + The validation lives in ``make_namespace_safe_name``; callers exercising + the public OpenAI schema-conversion function see the same error. + """ + cap = Capability( + capability_id="a__b", + name="Bad", + description="d", + safety_class=SafetyClass.READ, + ) + with pytest.raises(AdapterParseError, match="reserved namespace separator"): + openai_mod.capabilities_to_tools([cap]) + + +# ── Payload helpers ─────────────────────────────────────────────────────────── + + +def test_error_to_payload_shape() -> None: + payload = error_to_payload(capability_id="x.y", error="boom") + assert payload == {"error": True, "capability_id": "x.y", "message": "boom"} + + +# ── OpenAI: schema conversion ───────────────────────────────────────────────── + + +def test_openai_capabilities_to_tools_responses_format() -> None: + cap = _cap_with_model() + tools = openai_mod.capabilities_to_tools([cap]) + assert len(tools) == 1 + tool = tools[0] + assert tool["type"] == "function" + assert tool["name"] == "billing__list_invoices" + assert "[safety=READ]" in tool["description"] + assert "[sensitivity=PII]" in tool["description"] + assert tool["parameters"]["type"] == "object" + # No nested function key in Responses-API shape. + assert "function" not in tool + + +def test_openai_capabilities_to_tools_chat_completions_format() -> None: + cap = _cap_with_model() + tools = openai_mod.capabilities_to_tools([cap], format="chat_completions") + assert len(tools) == 1 + tool = tools[0] + assert tool["type"] == "function" + assert "function" in tool + assert tool["function"]["name"] == "billing__list_invoices" + assert tool["function"]["parameters"]["type"] == "object" + # No flat name in Chat-Completions shape. 
+ assert "name" not in tool + + +def test_openai_strict_mode_emits_strict_flag_and_normalises_schema() -> None: + cap = Capability( + capability_id="billing.update_invoice", + name="Update", + description="d", + safety_class=SafetyClass.WRITE, + parameters_model=_InvoiceArgs, + tool_hints=ToolHints(strict=True), + ) + tool = openai_mod.capabilities_to_tools([cap])[0] + assert tool["strict"] is True + # Strict normalisation forces every property required and additionalProperties=false. + assert tool["parameters"]["additionalProperties"] is False + assert set(tool["parameters"]["required"]) == {"operation", "customer_id", "limit"} + + +def test_openai_strict_with_optional_field_preserves_nullable() -> None: + """Documented escape hatch: ``Optional[T] = None`` survives strict normalisation. + + OpenAI strict mode lists every property in ``required`` (the adapter enforces + this), but accepts ``null`` for fields declared as nullable. Pydantic emits + ``Optional`` fields as ``anyOf: [..., {"type": "null"}]``; the normaliser + must preserve that shape so the LLM can effectively "omit" the field by + passing ``null``. + """ + + class WithOptional(BaseModel): + name: str + suffix: str | None = None + + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + parameters_model=WithOptional, + tool_hints=ToolHints(strict=True), + ) + tool = openai_mod.capabilities_to_tools([cap])[0] + assert tool["strict"] is True + # All properties land in required, including the Optional one. + assert set(tool["parameters"]["required"]) == {"name", "suffix"} + # The Optional field still advertises null as a valid value via anyOf. 
+ suffix_schema = tool["parameters"]["properties"]["suffix"] + any_of = suffix_schema.get("anyOf") or [] + assert any(branch.get("type") == "null" for branch in any_of), ( + f"Expected 'null' branch in anyOf for the Optional field; got {suffix_schema!r}" + ) + + +def test_openai_description_omits_sensitivity_when_none() -> None: + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + ) + tool = openai_mod.capabilities_to_tools([cap])[0] + assert "[safety=READ]" in tool["description"] + assert "[sensitivity=" not in tool["description"] + + +# ── OpenAI: tool_call → CapabilityRequest ───────────────────────────────────── + + +def test_openai_tool_call_to_request_chat_completions_shape() -> None: + tool_call = { + "id": "call_abc", + "type": "function", + "function": { + "name": "billing__list_invoices", + "arguments": json.dumps({"customer_id": "c-1"}), + }, + } + req = openai_mod.tool_call_to_request(tool_call) + assert isinstance(req, CapabilityRequest) + assert req.capability_id == "billing.list_invoices" + + +def test_openai_tool_call_to_request_responses_shape() -> None: + tool_call = { + "type": "function_call", + "call_id": "fc_xyz", + "name": "billing__list_invoices", + "arguments": json.dumps({"customer_id": "c-1"}), + } + req = openai_mod.tool_call_to_request(tool_call) + assert req.capability_id == "billing.list_invoices" + + +def test_openai_tool_call_to_request_raises_on_invalid_json() -> None: + tool_call = { + "type": "function_call", + "call_id": "fc_xyz", + "name": "billing__list_invoices", + "arguments": "{not valid", + } + with pytest.raises(AdapterParseError, match="not valid JSON"): + openai_mod.tool_call_to_request(tool_call) + + +def test_openai_tool_call_to_request_raises_on_missing_name() -> None: + with pytest.raises(AdapterParseError, match="missing a function name"): + openai_mod.tool_call_to_request({"id": "x", "type": "function", "function": {}}) + + +# ── OpenAI middleware: end-to-end 
───────────────────────────────────────────── + + +def test_openai_get_tools_lists_all_registered_capabilities( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = OpenAIMiddleware(kernel, reader_principal) + tools = mw.get_tools() + names = {t["name"] for t in tools} + assert "billing__list_invoices" in names + assert "billing__update_invoice" in names # WRITE capability still listed + + +@pytest.mark.asyncio +async def test_openai_handle_tool_calls_success_flow( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = OpenAIMiddleware(kernel, reader_principal) + tool_calls = [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": json.dumps({"operation": "billing.list_invoices"}), + } + ] + outputs = await mw.handle_tool_calls(tool_calls) + assert len(outputs) == 1 + out = outputs[0] + assert out["type"] == "function_call_output" + assert out["call_id"] == "fc_1" + payload = json.loads(out["output"]) + assert payload["capability_id"] == "billing.list_invoices" + assert payload["response_mode"] == "summary" + assert "error" not in payload + + +@pytest.mark.asyncio +async def test_openai_handle_tool_calls_policy_denied_surfaces_as_error( + kernel: Kernel, reader_principal: Principal +) -> None: + """A WRITE call by a reader becomes a tool-result error, not a raised exception.""" + mw = OpenAIMiddleware(kernel, reader_principal) + tool_calls = [ + { + "type": "function_call", + "call_id": "fc_deny", + "name": "billing__update_invoice", + "arguments": json.dumps({}), + } + ] + outputs = await mw.handle_tool_calls(tool_calls, justification="long enough justification") + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "Policy denied" in payload["message"] + + +@pytest.mark.asyncio +async def test_openai_handle_tool_calls_unknown_capability_surfaces_as_error( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = OpenAIMiddleware(kernel, 
reader_principal) + tool_calls = [ + { + "type": "function_call", + "call_id": "fc_unknown", + "name": "nonexistent__capability", + "arguments": "{}", + } + ] + outputs = await mw.handle_tool_calls(tool_calls) + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "not registered" in payload["message"] + + +@pytest.mark.asyncio +async def test_openai_handle_tool_calls_invalid_json_arguments_surfaces_as_error( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = OpenAIMiddleware(kernel, reader_principal) + tool_calls = [ + { + "type": "function_call", + "call_id": "fc_bad", + "name": "billing__list_invoices", + "arguments": "{not valid", + } + ] + outputs = await mw.handle_tool_calls(tool_calls) + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "not valid JSON" in payload["message"] + + +@pytest.mark.asyncio +async def test_openai_handle_tool_calls_skips_non_function_items( + kernel: Kernel, reader_principal: Principal +) -> None: + """Non-function items in response.output are silently skipped.""" + mw = OpenAIMiddleware(kernel, reader_principal) + output = await mw.handle_tool_calls( + [ + {"type": "message", "content": "thinking..."}, + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + }, + ] + ) + assert len(output) == 1 + assert output[0]["call_id"] == "fc_1" + + +@pytest.mark.asyncio +async def test_openai_chat_completions_output_shape( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = OpenAIMiddleware(kernel, reader_principal, format="chat_completions") + outputs = await mw.handle_tool_calls( + [ + { + "id": "call_chat", + "type": "function", + "function": { + "name": "billing__list_invoices", + "arguments": "{}", + }, + } + ] + ) + assert outputs[0]["role"] == "tool" + assert outputs[0]["tool_call_id"] == "call_chat" + assert "content" in outputs[0] + + +# ── OpenAI middleware: hooks 
────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_openai_hooks_fire_in_registration_order( + kernel: Kernel, reader_principal: Principal +) -> None: + seen: list[str] = [] + mw = OpenAIMiddleware(kernel, reader_principal) + mw.intercept_tool_call(lambda e: seen.append(f"pre1:{e.capability_id}")) + mw.intercept_tool_call(lambda e: seen.append(f"pre2:{e.capability_id}")) + mw.intercept_tool_result(lambda e: seen.append(f"post1:{e.capability_id}")) + mw.intercept_tool_result(lambda e: seen.append(f"post2:{e.capability_id}")) + + await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + } + ] + ) + assert seen == [ + "pre1:billing.list_invoices", + "pre2:billing.list_invoices", + "post1:billing.list_invoices", + "post2:billing.list_invoices", + ] + + +@pytest.mark.asyncio +async def test_openai_hooks_support_async_callables( + kernel: Kernel, reader_principal: Principal +) -> None: + """Async hooks are awaited; sync and async can be mixed in the same registration.""" + seen: list[str] = [] + + async def async_pre(event: ToolCallEvent) -> None: + seen.append(f"async_pre:{event.capability_id}") + + def sync_pre(event: ToolCallEvent) -> None: + seen.append(f"sync_pre:{event.capability_id}") + + mw = OpenAIMiddleware(kernel, reader_principal) + mw.intercept_tool_call(async_pre) + mw.intercept_tool_call(sync_pre) + await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + } + ] + ) + assert seen == ["async_pre:billing.list_invoices", "sync_pre:billing.list_invoices"] + + +@pytest.mark.asyncio +async def test_openai_pre_hook_can_abort(kernel: Kernel, reader_principal: Principal) -> None: + """Pre-hook setting aborted=True short-circuits without invoking the kernel.""" + + def gate(event: ToolCallEvent) -> None: + event.aborted = True + event.abort_reason = "manual 
approval required" + + mw = OpenAIMiddleware(kernel, reader_principal) + mw.intercept_tool_call(gate) + outputs = await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + } + ] + ) + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "Aborted by pre-invocation hook" in payload["message"] + assert "manual approval required" in payload["message"] + + +@pytest.mark.asyncio +async def test_openai_pre_hook_injects_justification_for_write( + kernel: Kernel, reader_principal: Principal, writer_principal: Principal +) -> None: + """A pre-hook can inject a justification so a WRITE call satisfies policy.""" + + def inject(event: ToolCallEvent) -> None: + event.justification = "approved by reviewer 12345 with sufficient length" + + mw = OpenAIMiddleware(kernel, writer_principal) + mw.intercept_tool_call(inject) + outputs = await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__update_invoice", + "arguments": json.dumps({"operation": "billing.update_invoice"}), + } + ] + ) + payload = json.loads(outputs[0]["output"]) + assert payload.get("error") is None + assert payload["capability_id"] == "billing.update_invoice" + + +@pytest.mark.asyncio +async def test_openai_per_call_justification_override( + kernel: Kernel, writer_principal: Principal +) -> None: + """Per-call _justification in args overrides the batch justification.""" + mw = OpenAIMiddleware(kernel, writer_principal) + outputs = await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__update_invoice", + "arguments": json.dumps( + { + "operation": "billing.update_invoice", + "_justification": "per-call long enough justification", + } + ), + } + ], + justification="short", # would fail if used + ) + payload = json.loads(outputs[0]["output"]) + assert payload.get("error") is None + + +@pytest.mark.asyncio 
+async def test_openai_post_hook_observes_frame( + kernel: Kernel, reader_principal: Principal +) -> None: + captured: list[ToolResultEvent] = [] + mw = OpenAIMiddleware(kernel, reader_principal) + mw.intercept_tool_result(lambda e: captured.append(e)) + await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + } + ] + ) + assert len(captured) == 1 + assert captured[0].frame is not None + assert captured[0].error is None + assert captured[0].capability_id == "billing.list_invoices" + + +# ── Anthropic: schema conversion ────────────────────────────────────────────── + + +def test_anthropic_capabilities_to_tools_preserves_dotted_id() -> None: + cap = _cap_with_model() + tools = anthropic_mod.capabilities_to_tools([cap]) + assert len(tools) == 1 + tool = tools[0] + # Anthropic preserves dots — no namespace transformation. + assert tool["name"] == "billing.list_invoices" + assert "input_schema" in tool + assert tool["input_schema"]["type"] == "object" + assert "cache_control" not in tool # no default, no per-cap hint + + +def test_anthropic_capabilities_to_tools_per_capability_cache_control() -> None: + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + tool_hints=ToolHints(cache_control={"type": "ephemeral"}), + ) + tool = anthropic_mod.capabilities_to_tools([cap])[0] + assert tool["cache_control"] == {"type": "ephemeral"} + + +def test_anthropic_capabilities_to_tools_default_cache_control_applied() -> None: + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + ) + tool = anthropic_mod.capabilities_to_tools([cap], default_cache_control={"type": "ephemeral"})[ + 0 + ] + assert tool["cache_control"] == {"type": "ephemeral"} + + +def test_anthropic_capabilities_to_tools_per_capability_overrides_default() -> None: + cap = Capability( + capability_id="x.y", + name="X", + 
description="d", + safety_class=SafetyClass.READ, + tool_hints=ToolHints(cache_control={"type": "static"}), + ) + tool = anthropic_mod.capabilities_to_tools([cap], default_cache_control={"type": "ephemeral"})[ + 0 + ] + assert tool["cache_control"] == {"type": "static"} + + +# ── Anthropic: tool_use → CapabilityRequest ─────────────────────────────────── + + +def test_anthropic_tool_use_to_request_preserves_dotted_id() -> None: + block = { + "type": "tool_use", + "id": "toolu_xyz", + "name": "billing.list_invoices", + "input": {"customer_id": "c-1"}, + } + req = anthropic_mod.tool_use_to_request(block) + assert req.capability_id == "billing.list_invoices" + + +def test_anthropic_tool_use_to_request_raises_on_missing_name() -> None: + with pytest.raises(AdapterParseError, match="missing a 'name'"): + anthropic_mod.tool_use_to_request({"type": "tool_use", "id": "x", "input": {}}) + + +def test_anthropic_tool_use_to_request_raises_on_non_dict_input() -> None: + with pytest.raises(AdapterParseError, match="must be an object"): + anthropic_mod.tool_use_to_request( + {"type": "tool_use", "id": "x", "name": "n", "input": "string"} + ) + + +# ── Anthropic middleware: end-to-end ────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_anthropic_handle_tool_uses_success_flow( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = AnthropicMiddleware(kernel, reader_principal) + blocks = [ + { + "type": "tool_use", + "id": "toolu_1", + "name": "billing.list_invoices", + "input": {"operation": "billing.list_invoices"}, + } + ] + results = await mw.handle_tool_uses(blocks) + assert len(results) == 1 + r = results[0] + assert r["type"] == "tool_result" + assert r["tool_use_id"] == "toolu_1" + assert r.get("is_error") is None or r["is_error"] is False + text_block = r["content"][0] + assert text_block["type"] == "text" + payload = json.loads(text_block["text"]) + assert payload["capability_id"] == "billing.list_invoices" + + 
+@pytest.mark.asyncio +async def test_anthropic_handle_tool_uses_skips_text_blocks( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = AnthropicMiddleware(kernel, reader_principal) + results = await mw.handle_tool_uses( + [ + {"type": "text", "text": "thinking..."}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "billing.list_invoices", + "input": {}, + }, + ] + ) + assert len(results) == 1 + assert results[0]["tool_use_id"] == "toolu_1" + + +@pytest.mark.asyncio +async def test_anthropic_handle_tool_uses_policy_denied_is_error_block( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = AnthropicMiddleware(kernel, reader_principal) + results = await mw.handle_tool_uses( + [ + { + "type": "tool_use", + "id": "toolu_deny", + "name": "billing.update_invoice", + "input": {}, + } + ], + justification="long enough justification", + ) + assert results[0]["is_error"] is True + payload = json.loads(results[0]["content"][0]["text"]) + assert payload["error"] is True + assert "Policy denied" in payload["message"] + + +@pytest.mark.asyncio +async def test_anthropic_handle_tool_uses_unknown_capability_is_error_block( + kernel: Kernel, reader_principal: Principal +) -> None: + mw = AnthropicMiddleware(kernel, reader_principal) + results = await mw.handle_tool_uses( + [{"type": "tool_use", "id": "toolu_x", "name": "nope.nada", "input": {}}] + ) + assert results[0]["is_error"] is True + payload = json.loads(results[0]["content"][0]["text"]) + assert "not registered" in payload["message"] + + +# ── Anthropic hooks ─────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_anthropic_hooks_fire_in_registration_order( + kernel: Kernel, reader_principal: Principal +) -> None: + seen: list[str] = [] + mw = AnthropicMiddleware(kernel, reader_principal) + mw.intercept_tool_call(lambda e: seen.append(f"pre:{e.vendor}")) + mw.intercept_tool_result(lambda e: seen.append(f"post:{e.vendor}")) + await 
mw.handle_tool_uses( + [ + { + "type": "tool_use", + "id": "toolu_1", + "name": "billing.list_invoices", + "input": {}, + } + ] + ) + assert seen == ["pre:anthropic", "post:anthropic"] + + +# ── Shared: validation, frame_to_payload ────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_middleware_input_validation_surfaces_as_tool_error( + reader_principal: Principal, +) -> None: + """When a capability has parameters_model, bad args become a tool-result error.""" + cap = _cap_with_model() + registry = CapabilityRegistry() + registry.register(cap) + + from agent_kernel import HMACTokenProvider, InMemoryDriver, StaticRouter + from agent_kernel.drivers.base import ExecutionContext + + driver = InMemoryDriver(driver_id="memory") + + def echo(ctx: ExecutionContext) -> dict[str, Any]: + return {"echo": ctx.args} + + driver.register_handler("billing.list_invoices", echo) + driver.register_handler("list_invoices", echo) + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret"), + router=StaticRouter(routes={"billing.list_invoices": ["memory"]}), + ) + kernel.register_driver(driver) + + mw = OpenAIMiddleware(kernel, reader_principal) + outputs = await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_bad", + "name": "billing__list_invoices", + "arguments": json.dumps({"limit": 5}), # missing required customer_id + } + ] + ) + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "validation failed" in payload["message"].lower() + + +@pytest.mark.asyncio +async def test_pre_hook_exception_becomes_tool_error( + kernel: Kernel, reader_principal: Principal +) -> None: + """A pre-hook that raises does not crash the batch — it becomes a tool error.""" + + def bad(event: ToolCallEvent) -> None: + raise RuntimeError("pre-hook explosion") + + mw = OpenAIMiddleware(kernel, reader_principal) + mw.intercept_tool_call(bad) + outputs = await mw.handle_tool_calls( + [ + 
{ + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + } + ] + ) + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "Pre-invocation hook raised" in payload["message"] + assert "pre-hook explosion" in payload["message"] + + +@pytest.mark.asyncio +async def test_post_hook_exception_is_logged_not_raised( + kernel: Kernel, reader_principal: Principal, caplog: pytest.LogCaptureFixture +) -> None: + """A post-hook that raises is logged but never crashes the batch.""" + + def bad(event: ToolResultEvent) -> None: + raise RuntimeError("post-hook explosion") + + mw = OpenAIMiddleware(kernel, reader_principal) + mw.intercept_tool_result(bad) + with caplog.at_level("WARNING", logger="agent_kernel.adapters._base"): + outputs = await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "{}", + } + ] + ) + # The call still succeeded — only the hook failed. 
+ payload = json.loads(outputs[0]["output"]) + assert payload.get("error") is None + assert any("post_hook_failed" in rec.message for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_driver_error_surfaces_as_tool_error(reader_principal: Principal) -> None: + """A DriverError is converted to a tool-result error, not raised.""" + from agent_kernel import HMACTokenProvider, InMemoryDriver, StaticRouter + from agent_kernel.drivers.base import ExecutionContext + from agent_kernel.errors import DriverError + + registry = CapabilityRegistry() + registry.register( + Capability( + capability_id="explode.now", + name="Explode", + description="d", + safety_class=SafetyClass.READ, + ) + ) + + def explode(ctx: ExecutionContext) -> dict[str, Any]: + raise DriverError("driver fell over") + + driver = InMemoryDriver(driver_id="memory") + driver.register_handler("explode.now", explode) + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret"), + router=StaticRouter(routes={"explode.now": ["memory"]}), + ) + kernel.register_driver(driver) + + mw = OpenAIMiddleware(kernel, reader_principal) + outputs = await mw.handle_tool_calls( + [ + { + "type": "function_call", + "call_id": "fc_1", + "name": "explode__now", + "arguments": "{}", + } + ] + ) + payload = json.loads(outputs[0]["output"]) + assert payload["error"] is True + assert "Driver error" in payload["message"] + + +# ── OpenAI: argument parsing edge cases ─────────────────────────────────────── + + +def test_openai_parse_pre_parsed_dict_arguments() -> None: + """Some clients pre-parse arguments; the adapter accepts dict input too.""" + req = openai_mod.tool_call_to_request( + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": {"customer_id": "c-1"}, + } + ) + assert req.capability_id == "billing.list_invoices" + + +def test_openai_parse_arguments_array_raises() -> None: + """A JSON array (not object) in arguments is a 
contract violation.""" + with pytest.raises(AdapterParseError, match="must decode to a JSON object"): + openai_mod.tool_call_to_request( + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "[1, 2, 3]", + } + ) + + +def test_openai_parse_arguments_wrong_type_raises() -> None: + """Argument values that are neither string nor dict are rejected.""" + with pytest.raises(AdapterParseError, match="must be a JSON string or dict"): + openai_mod.tool_call_to_request( + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": 42, + } + ) + + +def test_openai_empty_arguments_string_treated_as_empty_dict() -> None: + req = openai_mod.tool_call_to_request( + { + "type": "function_call", + "call_id": "fc_1", + "name": "billing__list_invoices", + "arguments": "", + } + ) + assert req.capability_id == "billing.list_invoices" + + +def test_openai_strict_with_chat_completions_format() -> None: + """Strict flag lands inside the nested function for Chat-Completions shape.""" + cap = Capability( + capability_id="x.y", + name="X", + description="d", + safety_class=SafetyClass.READ, + parameters_model=_InvoiceArgs, + tool_hints=ToolHints(strict=True), + ) + tool = openai_mod.capabilities_to_tools([cap], format="chat_completions")[0] + assert tool["function"]["strict"] is True + assert tool["function"]["parameters"]["additionalProperties"] is False + + +# ── Anthropic: edge cases ───────────────────────────────────────────────────── + + +def test_anthropic_tool_use_to_request_treats_none_input_as_empty() -> None: + """Anthropic may emit ``input: null`` for zero-argument tools.""" + req = anthropic_mod.tool_use_to_request( + {"type": "tool_use", "id": "x", "name": "billing.summarize_spend", "input": None} + ) + assert req.capability_id == "billing.summarize_spend" + + +@pytest.mark.asyncio +async def test_anthropic_handle_tool_uses_parse_error_surfaces_as_error_block( + kernel: Kernel, 
 reader_principal: Principal +) -> None: + """A malformed tool_use block produces an is_error result, not an exception.""" + mw = AnthropicMiddleware(kernel, reader_principal) + results = await mw.handle_tool_uses( + [ + # name field is missing, so tool_use_to_request raises AdapterParseError. + {"type": "tool_use", "id": "toolu_bad", "input": {}}, + ] + ) + assert len(results) == 1 + assert results[0]["is_error"] is True + payload = json.loads(results[0]["content"][0]["text"]) + assert "missing a 'name'" in payload["message"] + + +def test_frame_to_payload_shape(kernel: Kernel) -> None: + """frame_to_payload returns the canonical JSON shape both adapters share.""" + import datetime + + from agent_kernel.models import Frame, Handle + + handle = Handle( + handle_id="h-1", + capability_id="x.y", + created_at=datetime.datetime.now(tz=datetime.timezone.utc), + expires_at=datetime.datetime.now(tz=datetime.timezone.utc), + total_rows=42, + ) + frame = Frame( + action_id="a-1", + capability_id="x.y", + response_mode="summary", + facts=["one fact"], + table_preview=[{"a": 1}], + warnings=["careful"], + handle=handle, + ) + payload = frame_to_payload(frame) + assert payload == { + "action_id": "a-1", + "capability_id": "x.y", + "response_mode": "summary", + "facts": ["one fact"], + "table_preview": [{"a": 1}], + "warnings": ["careful"], + "handle": {"handle_id": "h-1", "total_rows": 42}, + }