diff --git a/CHANGELOG.md b/CHANGELOG.md index 24d3a90..34a8cf8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Cross-invocation context budget manager (`BudgetManager`) tracks cumulative token usage across + multiple `Kernel.invoke()` calls within a session. When attached to a `Kernel` via the new + `budget_manager` keyword argument, the kernel reserves a budget slice before each invocation + and reconciles actual frame-payload usage afterwards. As the remaining budget shrinks the + requested `response_mode` is auto-escalated to a more aggressive tier (> 50% remaining keeps + the caller's mode; 20–50% downgrades `raw` to `table`; 5–20% floors at `summary`; < 5% forces + `handle_only`). `Kernel.invoke(..., dry_run=True)` now also reports `budget_remaining` and the + escalated `response_mode` when a manager is configured. The `BudgetManager` is optional and + off by default — existing kernels are unchanged. (#44) +- `TokenCounter` protocol and `default_token_counter` (character-based `len(json.dumps(...))//4` + approximation) provide pluggable token counting without runtime dependencies. A new optional + `[tiktoken]` extra is reserved for callers that want to plug in `tiktoken`-based counting. +- `BudgetExhausted(AgentKernelError)` raised by `BudgetManager.allocate()` (and by + `Kernel.invoke()` before driver execution) when the cumulative session budget is fully spent. +- `BudgetConfigError(AgentKernelError)` raised by `BudgetManager` for invalid configuration or + validation failures (non-positive budgets, negative allocate/record/release amounts), replacing + bare `ValueError` so callers can catch budget mistakes via the `AgentKernelError` hierarchy + per `AGENTS.md` ("never raise bare ValueError to callers"). +- New public exports: `BudgetManager`, `BudgetExhausted`, `BudgetConfigError`, `TokenCounter`, + `default_token_counter`, and `Kernel.budget` accessor property. - LLM tool-format adapters and middleware (`agent_kernel.adapters`): `OpenAIMiddleware` (OpenAI Responses API + Chat Completions, auto-detected on input) and `AnthropicMiddleware` (Anthropic Messages with `cache_control` support). Both translate `Capability` objects to vendor tool diff --git a/docs/context_firewall.md b/docs/context_firewall.md index 9d14d99..3b1623a 100644 --- a/docs/context_firewall.md +++ b/docs/context_firewall.md @@ -62,3 +62,59 @@ Summaries are produced deterministically: - **dict** → key list + per-value type/value - **string** → truncated to 500 chars - **other** → repr() truncated to 200 chars + +## Cross-invocation budgets + +The per-invocation `Budgets` above cap a single Frame. A separate +`BudgetManager` tracks cumulative token usage *across* invocations within a +session. It is optional — if you don't attach one, kernel behavior is +unchanged. + +```python +from agent_kernel import BudgetManager, Kernel + +manager = BudgetManager(total_budget=100_000) +kernel = Kernel(registry, budget_manager=manager) +``` + +Per `invoke()` the kernel: + +1. Reserves a slice of the remaining budget (default 4,000 tokens). If the + budget is empty, `BudgetExhausted` is raised before the driver runs. +2. Consults `manager.suggested_mode(requested)` to escalate the requested + `response_mode` to a more aggressive tier as the remaining budget shrinks. +3. After the firewall produces a Frame, counts the actual tokens in the + LLM-facing payload and reconciles them against the reservation. + +Escalation table: + +| Budget remaining | Suggested mode (effective `response_mode`) | +|-----------------:|------------------------------------------------| +| > 50% | Caller's requested mode (no change) | +| 20% – 50% | `table` (when caller requested `raw`) | +| 5% – 20% (≥ 5%) | `summary` (floor — never *relaxes* to `table`) | +| < 5% | `handle_only` | + +Boundaries land in the more-conservative tier — exactly 50% remaining +downgrades `raw` to `table`, exactly 20% floors at `summary`, and only when +remaining drops *below* 5% does `handle_only` take over. + +`Kernel.invoke(..., dry_run=True)` mirrors the escalation and reports +`budget_remaining` in the returned `DryRunResult`, so callers can preview +what their next live invocation would actually return. + +Plug a different token counter (for example, a `tiktoken`-based one) via the +`TokenCounter` protocol: + +```python +import tiktoken # pip install weaver-kernel[tiktoken] +enc = tiktoken.encoding_for_model("gpt-4o") + +def tiktoken_counter(value): + return len(enc.encode(str(value))) + +manager = BudgetManager(total_budget=128_000, token_counter=tiktoken_counter) +``` + +The default counter (`default_token_counter`) is a character-based +`len(json.dumps(value)) // 4` approximation with no extra dependencies. diff --git a/pyproject.toml b/pyproject.toml index 0159973..d6e3588 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ policy = [ "pyyaml>=6.0", "tomli>=2.0; python_version<'3.11'", ] +tiktoken = ["tiktoken>=0.6"] [tool.hatch.build.targets.wheel] packages = ["src/agent_kernel"] diff --git a/src/agent_kernel/__init__.py b/src/agent_kernel/__init__.py index cdef01c..919c2bc 100644 --- a/src/agent_kernel/__init__.py +++ b/src/agent_kernel/__init__.py @@ -19,7 +19,7 @@ Firewall:: - from agent_kernel import Firewall, Budgets + from agent_kernel import Firewall, Budgets, BudgetManager Handles & traces:: @@ -35,6 +35,7 @@ AgentKernelError, TokenExpired, TokenInvalid, TokenScopeError, PolicyDenied, PolicyConfigError, DriverError, FirewallError, + BudgetExhausted, BudgetConfigError, CapabilityNotFound, HandleNotFound, HandleExpired, ) """ @@ -48,6 +49,8 @@ from .errors import ( AdapterParseError, AgentKernelError, + BudgetConfigError, + BudgetExhausted, CapabilityAlreadyRegistered, CapabilityNotFound, DriverError, @@ -61,7 +64,9 @@ TokenRevoked, TokenScopeError, ) +from .firewall.budget_manager import BudgetManager from .firewall.budgets import Budgets +from .firewall.token_counting import TokenCounter, default_token_counter from .firewall.transform import Firewall from .handles import HandleStore from .kernel import Kernel @@ -125,6 +130,8 @@ # errors "AdapterParseError", "AgentKernelError", + "BudgetConfigError", + "BudgetExhausted", "CapabilityAlreadyRegistered", "CapabilityNotFound", "DriverError", @@ -156,8 +163,11 @@ "MCPDriver", "make_billing_driver", # firewall - "Firewall", + "BudgetManager", "Budgets", + "Firewall", + "TokenCounter", + "default_token_counter", # stores "HandleStore", "TraceStore", diff --git a/src/agent_kernel/errors.py b/src/agent_kernel/errors.py index 3b5e342..263e9ed 100644 --- a/src/agent_kernel/errors.py +++ b/src/agent_kernel/errors.py @@ -49,6 +49,27 @@ class FirewallError(AgentKernelError): """Raised when the context firewall cannot transform a raw result.""" +class BudgetExhausted(AgentKernelError): + """Raised when a :class:`~agent_kernel.firewall.budgets.BudgetManager` has + no remaining cross-invocation context budget. + + Distinct from :class:`FirewallError`: this error fires *before* the + firewall transforms data, signalling that the caller has consumed the + entire session-level context budget. The current invocation never runs + the driver. + """ + + +class BudgetConfigError(AgentKernelError): + """Raised when a :class:`~agent_kernel.firewall.budgets.BudgetManager` is + constructed with invalid parameters, or asked to allocate/record/release + a negative amount. + + Used in place of bare :class:`ValueError` so callers can catch budget + configuration mistakes without swallowing unrelated stdlib errors. + """ + + # ── Adapter errors ──────────────────────────────────────────────────────────── diff --git a/src/agent_kernel/firewall/__init__.py b/src/agent_kernel/firewall/__init__.py index 8912822..17a45d4 100644 --- a/src/agent_kernel/firewall/__init__.py +++ b/src/agent_kernel/firewall/__init__.py @@ -1,8 +1,18 @@ """Firewall sub-package exports.""" +from .budget_manager import BudgetManager from .budgets import Budgets from .redaction import redact from .summarize import summarize +from .token_counting import TokenCounter, default_token_counter from .transform import Firewall -__all__ = ["Budgets", "Firewall", "redact", "summarize"] +__all__ = [ + "BudgetManager", + "Budgets", + "Firewall", + "TokenCounter", + "default_token_counter", + "redact", + "summarize", +] diff --git a/src/agent_kernel/firewall/budget_manager.py b/src/agent_kernel/firewall/budget_manager.py new file mode 100644 index 0000000..3e8f19f --- /dev/null +++ b/src/agent_kernel/firewall/budget_manager.py @@ -0,0 +1,275 @@ +"""Cross-invocation session-level budget manager. + +A :class:`BudgetManager` tracks cumulative token usage across multiple +:meth:`~agent_kernel.Kernel.invoke` calls and suggests +:class:`~agent_kernel.models.ResponseMode` escalation as the remaining budget +shrinks. The manager is optional — a :class:`~agent_kernel.Kernel` +constructed without one behaves identically to earlier versions of the +library. + +This module is the implementation of issue #44. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from typing import Any + +from ..errors import BudgetConfigError, BudgetExhausted +from ..models import ResponseMode +from .token_counting import TokenCounter, default_token_counter + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class _BudgetState: + """Internal mutable state for :class:`BudgetManager`.""" + + total: int + used: int + reserved: int + + +class BudgetManager: + """Tracks cumulative token usage across invocations within a session. + + When attached to a :class:`~agent_kernel.Kernel` via the + ``budget_manager`` constructor parameter, the kernel calls + :meth:`allocate` before every invocation to reserve a slice of the + remaining budget and :meth:`record_usage` after the firewall has + produced a Frame to reconcile actual consumption. Once the remaining + budget shrinks, the kernel consults :meth:`suggested_mode` to escalate + the requested response mode to a more aggressive summarisation tier. + + Concurrency: :meth:`allocate`, :meth:`record_usage`, and :meth:`release` + are serialised behind an internal :class:`asyncio.Lock` so concurrent + invocations from the same kernel see consistent budget state. + + Example:: + + manager = BudgetManager(total_budget=100_000) + kernel = Kernel(registry, budget_manager=manager) + frame = await kernel.invoke(token, principal=p, args={}) # consumes + assert manager.remaining < 100_000 + """ + + __slots__ = ("_state", "_lock", "_counter", "_default_request") + + def __init__( + self, + total_budget: int = 100_000, + *, + token_counter: TokenCounter | None = None, + default_request: int = 4_000, + ) -> None: + """Initialise a :class:`BudgetManager`. + + Args: + total_budget: Total token budget for the session. Must be > 0. + token_counter: Optional custom :class:`TokenCounter`. Defaults + to :func:`default_token_counter` (chars/4 approximation). + default_request: Tokens to reserve per :meth:`allocate` call + when the caller does not pass an explicit ``requested`` + amount. Must be > 0. + + Raises: + BudgetConfigError: If ``total_budget`` or ``default_request`` + is non-positive. + """ + if total_budget <= 0: + raise BudgetConfigError("total_budget must be positive") + if default_request <= 0: + raise BudgetConfigError("default_request must be positive") + self._state = _BudgetState(total=total_budget, used=0, reserved=0) + self._lock = asyncio.Lock() + self._counter: TokenCounter = token_counter or default_token_counter + self._default_request = default_request + + # ── Read-only properties ────────────────────────────────────────────────── + + @property + def total_budget(self) -> int: + """The total session budget configured at construction.""" + return self._state.total + + @property + def remaining(self) -> int: + """Budget remaining after accounting for committed *and* reserved use. + + Reservations are subtracted so that two concurrent :meth:`allocate` + calls cannot both believe the same budget slice is free. + """ + return max(0, self._state.total - self._state.used - self._state.reserved) + + @property + def used(self) -> int: + """Tokens already committed via :meth:`record_usage`.""" + return self._state.used + + @property + def usage_fraction(self) -> float: + """Fraction of the total budget already committed (``used / total``). + + Reservations are *not* counted here — only committed usage. The + value is always in ``[0.0, 1.0]``. + """ + if self._state.total == 0: + return 1.0 + return min(1.0, self._state.used / self._state.total) + + # ── Allocation / recording ──────────────────────────────────────────────── + + async def allocate(self, requested: int | None = None) -> int: + """Reserve a budget slice for an upcoming invocation. + + Args: + requested: Tokens the caller would like to spend. ``None`` uses + the manager's ``default_request``. Negative values are + rejected. + + Returns: + The number of tokens actually reserved + (``min(requested, remaining)``). May be less than ``requested`` + when the budget is nearly exhausted but non-zero. + + Raises: + BudgetExhausted: If no budget remains at all. + BudgetConfigError: If ``requested`` is negative. + """ + if requested is not None and requested < 0: + raise BudgetConfigError("requested must be non-negative") + async with self._lock: + if self.remaining <= 0: + raise BudgetExhausted( + f"Session budget exhausted: used {self._state.used} of " + f"{self._state.total} tokens (no budget remaining)." + ) + want = self._default_request if requested is None else requested + granted = min(want, self.remaining) + self._state.reserved += granted + logger.debug( + "budget_allocate", + extra={ + "requested": want, + "granted": granted, + "remaining": self.remaining, + "used": self._state.used, + }, + ) + return granted + + async def record_usage(self, actual: int, *, reserved: int | None = None) -> None: + """Reconcile a completed invocation against a prior reservation. + + Args: + actual: Actual tokens consumed (computed via + :meth:`count_tokens` on the Frame payload). Negative values + are rejected. + reserved: The value previously returned by :meth:`allocate`. If + omitted, the reservation pool is left untouched and only + ``actual`` is added to ``used``. + + Raises: + BudgetConfigError: If ``actual`` or ``reserved`` is negative. + """ + if actual < 0: + raise BudgetConfigError("actual must be non-negative") + if reserved is not None and reserved < 0: + raise BudgetConfigError("reserved must be non-negative") + async with self._lock: + if reserved is not None: + self._state.reserved = max(0, self._state.reserved - reserved) + self._state.used = min(self._state.total, self._state.used + actual) + logger.debug( + "budget_record", + extra={ + "actual": actual, + "reserved_released": reserved or 0, + "remaining": self.remaining, + "used": self._state.used, + }, + ) + + async def release(self, reserved: int) -> None: + """Release a reservation without recording any usage. + + Called when an invocation fails before the firewall runs (for + example the driver raised :class:`~agent_kernel.errors.DriverError` + or the firewall itself raised). + + Args: + reserved: The amount previously returned by :meth:`allocate`. + + Raises: + BudgetConfigError: If ``reserved`` is negative. + """ + if reserved < 0: + raise BudgetConfigError("reserved must be non-negative") + async with self._lock: + self._state.reserved = max(0, self._state.reserved - reserved) + + # ── Counting / mode suggestion ──────────────────────────────────────────── + + def count_tokens(self, value: Any) -> int: + """Count tokens for *value* using the configured :class:`TokenCounter`.""" + return self._counter(value) + + def suggested_mode(self, requested: ResponseMode) -> ResponseMode: + """Suggest a response mode based on remaining budget. + + Escalation table (issue #44). Boundaries land in the more-conservative + tier, so 50% exactly downgrades raw and 20% exactly floors at summary: + + ================= ============================================== + Budget remaining Suggested mode + ================= ============================================== + > 50% Caller's requested mode + 20% – 50% ``table`` (if caller requested ``raw``) + 5% – 20% (≥ 5%) ``summary`` (or stricter if already requested) + < 5% ``handle_only`` + ================= ============================================== + + The suggestion never *relaxes* a stricter caller-requested mode — + if the caller asked for ``handle_only`` the result is always + ``handle_only``. ``raw`` is downgraded as soon as remaining drops + to 50% or below because raw payloads are unbounded and the + cross-session budget cannot accommodate them. + + Args: + requested: Mode the caller asked for. + + Returns: + Mode the kernel should actually use for the upcoming + invocation. Deterministic — no randomness. + """ + if self._state.total == 0: + return "handle_only" + fraction_remaining = self.remaining / self._state.total + if fraction_remaining < 0.05: + return "handle_only" + if fraction_remaining <= 0.20: + return _stricter(requested, "summary") + if fraction_remaining <= 0.50: + if requested == "raw": + return "table" + return requested + return requested + + +# ── Internal helpers ────────────────────────────────────────────────────────── + + +_MODE_RANK: dict[ResponseMode, int] = { + "raw": 0, + "table": 1, + "summary": 2, + "handle_only": 3, +} + + +def _stricter(requested: ResponseMode, floor: ResponseMode) -> ResponseMode: + """Return whichever of *requested* and *floor* is stricter (higher rank).""" + return requested if _MODE_RANK[requested] >= _MODE_RANK[floor] else floor diff --git a/src/agent_kernel/firewall/budgets.py b/src/agent_kernel/firewall/budgets.py index 9ad367f..0db7fc5 100644 --- a/src/agent_kernel/firewall/budgets.py +++ b/src/agent_kernel/firewall/budgets.py @@ -1,7 +1,9 @@ -"""Budgets dataclass for the context firewall. +"""Per-invocation firewall budget caps. -Canonical definition of :class:`Budgets`. Re-exported via -``agent_kernel.firewall`` and the top-level ``agent_kernel`` package. +Defines :class:`Budgets`, the dataclass enforced by the +:class:`~agent_kernel.firewall.transform.Firewall` when shaping a single +:class:`~agent_kernel.models.Frame`. Cross-invocation cumulative tracking +lives in :mod:`agent_kernel.firewall.budget_manager`. """ from __future__ import annotations diff --git a/src/agent_kernel/firewall/token_counting.py b/src/agent_kernel/firewall/token_counting.py new file mode 100644 index 0000000..8662383 --- /dev/null +++ b/src/agent_kernel/firewall/token_counting.py @@ -0,0 +1,41 @@ +"""Token counting protocol and default character-based approximation. + +The :class:`TokenCounter` protocol lets callers plug in vendor-specific +token counters (for example, a ``tiktoken``-based one) into the +:class:`~agent_kernel.firewall.budget_manager.BudgetManager`. The default +implementation, :func:`default_token_counter`, uses +``len(json.dumps(value, default=str)) // 4`` and has no extra dependencies. +""" + +from __future__ import annotations + +import json +from typing import Any, Protocol + + +class TokenCounter(Protocol): + """Approximates the token cost of an arbitrary value. + + Implementations must be deterministic and side-effect-free. + """ + + def __call__(self, value: Any) -> int: ... + + +def default_token_counter(value: Any) -> int: + """Character-based token approximation (``chars // 4``). + + Args: + value: Any JSON-serialisable value. Non-serialisable values fall back + to ``str(value)``. + + Returns: + A non-negative integer approximating the token count. + """ + if value is None: + return 0 + try: + text = json.dumps(value, default=str) + except (TypeError, ValueError): + text = str(value) + return max(0, len(text) // 4) diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index 220853d..6c4cb24 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -10,6 +10,7 @@ from .drivers.base import Driver, ExecutionContext from .enums import SafetyClass from .errors import AgentKernelError, DriverError +from .firewall.budget_manager import BudgetManager from .firewall.transform import Firewall from .handles import HandleStore from .models import ( @@ -35,6 +36,22 @@ logger = logging.getLogger(__name__) +def _frame_payload(frame: Frame) -> Any: + """Return the LLM-facing payload from a :class:`Frame` for token counting. + + Only the data the LLM actually sees is counted — facts, table rows, + or raw data. Provenance metadata, action IDs, and handle IDs are + skipped because they are kernel bookkeeping rather than context. + """ + if frame.response_mode == "raw": + return frame.raw_data + if frame.response_mode == "table": + return frame.table_preview + if frame.response_mode == "handle_only": + return None + return frame.facts + + class Kernel: """The central orchestrator for capability-based AI agent security. @@ -62,6 +79,7 @@ def __init__( firewall: Firewall | None = None, handle_store: HandleStore | None = None, trace_store: TraceStore | None = None, + budget_manager: BudgetManager | None = None, ) -> None: self._registry = registry self._policy: PolicyEngine = policy or DefaultPolicyEngine() @@ -70,8 +88,16 @@ def __init__( self._firewall = firewall or Firewall() self._handle_store = handle_store or HandleStore() self._trace_store = trace_store or TraceStore() + self._budget_manager = budget_manager self._drivers: dict[str, Driver] = {} + # ── Budget accessor ──────────────────────────────────────────────────────── + + @property + def budget(self) -> BudgetManager | None: + """The cross-invocation :class:`BudgetManager`, or ``None`` if none is configured.""" + return self._budget_manager + # ── Driver registration ──────────────────────────────────────────────────── def register_driver(self, driver: Driver) -> None: @@ -266,6 +292,12 @@ async def invoke( effective_response_mode: ResponseMode = response_mode if response_mode == "raw" and "admin" not in principal.roles: effective_response_mode = "summary" + # Mirror the BudgetManager escalation an actual invoke would apply, + # so dry-run reports the mode the caller would really see. + if self._budget_manager is not None: + effective_response_mode = self._budget_manager.suggested_mode( + effective_response_mode + ) _cost_map: dict[SafetyClass, Literal["low", "medium", "high"]] = { SafetyClass.READ: "low", SafetyClass.WRITE: "medium", @@ -283,12 +315,38 @@ async def invoke( operation=operation, resolved_args=args, response_mode=effective_response_mode, - budget_remaining=None, + budget_remaining=( + self._budget_manager.remaining if self._budget_manager is not None else None + ), estimated_cost=_cost_map[capability.safety_class], ) action_id = str(uuid.uuid4()) + # ── Mirror Firewall's admin-only ``raw`` gate ───────────────────────── + # The Firewall downgrades raw → summary for non-admin principals + # (see firewall/transform.py and docs/agent-context/invariants.md). + # We must mirror that downgrade *before* deciding whether to store a + # handle and before consulting the budget manager, otherwise a + # non-admin asking for raw would get a summary frame *without* a + # handle (because the kernel skipped handle creation thinking the + # mode was still raw). + effective_mode: ResponseMode = response_mode + if response_mode == "raw" and "admin" not in principal.roles: + effective_mode = "summary" + + # ── Cross-invocation budget allocation & mode escalation ────────────── + # When a BudgetManager is attached, reserve a slice of the cumulative + # session budget before driver execution. The manager raises + # BudgetExhausted if no budget remains. The requested response_mode is + # escalated to a more aggressive tier as the remaining budget shrinks + # (see BudgetManager.suggested_mode). This change is invisible to + # callers without a BudgetManager — the original mode flows through. + reserved_tokens: int | None = None + if self._budget_manager is not None: + reserved_tokens = await self._budget_manager.allocate() + effective_mode = self._budget_manager.suggested_mode(effective_mode) + _log_ctx = { "action_id": action_id, "principal_id": principal.principal_id, @@ -296,7 +354,12 @@ async def invoke( } logger.info( "invoke_start", - extra={**_log_ctx, "token_id": token.token_id, "response_mode": response_mode}, + extra={ + **_log_ctx, + "token_id": token.token_id, + "response_mode": response_mode, + "effective_mode": effective_mode, + }, ) # ── Execute with fallback ───────────────────────────────────────────── @@ -329,6 +392,9 @@ async def invoke( continue if raw_result is None: + # Release any reservation — no tokens were spent by the firewall. + if self._budget_manager is not None and reserved_tokens is not None: + await self._budget_manager.release(reserved_tokens) err_msg = str(last_error) if last_error else "No drivers available." logger.warning("invoke_failure", extra={**_log_ctx, "error": err_msg}) trace = ActionTrace( @@ -349,22 +415,39 @@ async def invoke( # ── Store handle ────────────────────────────────────────────────────── handle: Handle | None = None - if response_mode != "raw": + if effective_mode != "raw": handle = self._handle_store.store( capability_id=token.capability_id, data=raw_result.data, ) - # ── Firewall transform ──────────────────────────────────────────────── - frame = self._firewall.transform( - raw_result, - action_id=action_id, - principal_id=principal.principal_id, - principal_roles=list(principal.roles), - response_mode=response_mode, - constraints=token.constraints, - handle=handle, - ) + # ── Firewall transform + budget reconciliation ──────────────────────── + # Both steps run inside a try/finally so a Firewall exception (e.g. + # malformed constraint inputs) still releases any outstanding budget + # reservation. record_usage replaces the reservation with committed + # usage; the finally branch only fires if we never got there. + reservation_consumed = False + try: + frame = self._firewall.transform( + raw_result, + action_id=action_id, + principal_id=principal.principal_id, + principal_roles=list(principal.roles), + response_mode=effective_mode, + constraints=token.constraints, + handle=handle, + ) + if self._budget_manager is not None and reserved_tokens is not None: + actual_tokens = self._budget_manager.count_tokens(_frame_payload(frame)) + await self._budget_manager.record_usage(actual_tokens, reserved=reserved_tokens) + reservation_consumed = True + finally: + if ( + not reservation_consumed + and self._budget_manager is not None + and reserved_tokens is not None + ): + await self._budget_manager.release(reserved_tokens) # ── Record trace ────────────────────────────────────────────────────── trace = ActionTrace( diff --git a/tests/test_firewall.py b/tests/test_firewall.py index 391afd4..459820a 100644 --- a/tests/test_firewall.py +++ b/tests/test_firewall.py @@ -3,8 +3,17 @@ from __future__ import annotations import datetime +from typing import Any -from agent_kernel import Firewall +import pytest + +from agent_kernel import ( + BudgetConfigError, + BudgetExhausted, + BudgetManager, + Firewall, + default_token_counter, +) from agent_kernel.firewall.budgets import Budgets from agent_kernel.firewall.summarize import summarize from agent_kernel.models import Handle, RawResult @@ -286,3 +295,242 @@ def test_summarize_dict_max_facts() -> None: data = {"a": 1, "b": 2, "c": 3} facts = summarize(data, max_facts=2) assert len(facts) <= 2 + + +# ── Token counting ───────────────────────────────────────────────────────────── + + +def test_default_token_counter_none_is_zero() -> None: + assert default_token_counter(None) == 0 + + +def test_default_token_counter_str_chars_over_four() -> None: + # JSON-encoded form is "hello world" with quotes → 13 chars → 3 tokens. + assert default_token_counter("hello world") == 3 + + +def test_default_token_counter_dict_uses_json_chars() -> None: + value: dict[str, Any] = {"id": 1, "name": "alice"} + # len('{"id": 1, "name": "alice"}') == 26 → 26 // 4 == 6. + assert default_token_counter(value) == 6 + + +def test_default_token_counter_non_json_falls_back_to_str() -> None: + class NotSerialisable: + def __repr__(self) -> str: + return "X" * 100 + + # The repr is 100 chars; default=str → 100 chars → 25 tokens. + # Wrapping in JSON adds two quotes → 102 chars → 25 tokens. + assert default_token_counter(NotSerialisable()) == 25 + + +# ── BudgetManager: construction validation ───────────────────────────────────── + + +def test_budget_manager_rejects_non_positive_total() -> None: + with pytest.raises(BudgetConfigError, match="total_budget must be positive"): + BudgetManager(total_budget=0) + with pytest.raises(BudgetConfigError, match="total_budget must be positive"): + BudgetManager(total_budget=-1) + + +def test_budget_manager_rejects_non_positive_default_request() -> None: + with pytest.raises(BudgetConfigError, match="default_request must be positive"): + BudgetManager(total_budget=100, default_request=0) + + +# ── BudgetManager: allocation / recording ───────────────────────────────────── + + +@pytest.mark.asyncio +async def test_allocate_grants_full_when_under_budget() -> None: + bm = BudgetManager(total_budget=1000) + granted = await bm.allocate(200) + assert granted == 200 + assert bm.remaining == 800 + assert bm.used == 0 # reservation, not commit + + +@pytest.mark.asyncio +async def test_allocate_caps_at_remaining() -> None: + bm = BudgetManager(total_budget=1000) + await bm.allocate(700) # reserve 700 + granted = await bm.allocate(500) + # Only 300 is free after the first reservation. + assert granted == 300 + assert bm.remaining == 0 + + +@pytest.mark.asyncio +async def test_allocate_uses_default_request_when_none() -> None: + bm = BudgetManager(total_budget=1000, default_request=250) + granted = await bm.allocate() + assert granted == 250 + + +@pytest.mark.asyncio +async def test_allocate_rejects_negative_request() -> None: + bm = BudgetManager(total_budget=1000) + with pytest.raises(BudgetConfigError, match="non-negative"): + await bm.allocate(-10) + + +@pytest.mark.asyncio +async def test_allocate_raises_budget_exhausted_when_empty() -> None: + bm = BudgetManager(total_budget=100) + await bm.allocate(100) + await bm.record_usage(100, reserved=100) + with pytest.raises(BudgetExhausted, match="Session budget exhausted"): + await bm.allocate(10) + + +@pytest.mark.asyncio +async def test_record_usage_reconciles_under_reservation() -> None: + bm = BudgetManager(total_budget=1000) + reserved = await bm.allocate(400) + await bm.record_usage(150, reserved=reserved) + # Released the 400 reservation, committed 150 → remaining = 850. + assert bm.used == 150 + assert bm.remaining == 850 + + +@pytest.mark.asyncio +async def test_record_usage_caps_used_at_total() -> None: + # Defensive: actual > total should never push used above total. + bm = BudgetManager(total_budget=100) + await bm.record_usage(999) + assert bm.used == 100 + assert bm.remaining == 0 + + +@pytest.mark.asyncio +async def test_record_usage_rejects_negative() -> None: + bm = BudgetManager(total_budget=1000) + with pytest.raises(BudgetConfigError, match="non-negative"): + await bm.record_usage(-1) + + +@pytest.mark.asyncio +async def test_record_usage_rejects_negative_reserved() -> None: + bm = BudgetManager(total_budget=1000) + with pytest.raises(BudgetConfigError, match="reserved must be non-negative"): + await bm.record_usage(0, reserved=-5) + + +@pytest.mark.asyncio +async def test_release_returns_reservation_to_pool() -> None: + bm = BudgetManager(total_budget=1000) + reserved = await bm.allocate(400) + await bm.release(reserved) + assert bm.remaining == 1000 + assert bm.used == 0 + + +@pytest.mark.asyncio +async def test_release_rejects_negative() -> None: + bm = BudgetManager(total_budget=1000) + with pytest.raises(BudgetConfigError, match="reserved must be non-negative"): + await bm.release(-1) + + +def test_budget_config_error_is_agent_kernel_error() -> None: + """``BudgetConfigError`` is part of the public exception hierarchy.""" + from agent_kernel import AgentKernelError + + assert issubclass(BudgetConfigError, AgentKernelError) + + +# ── BudgetManager: properties ────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_usage_fraction_progresses() -> None: + bm = BudgetManager(total_budget=200) + assert bm.usage_fraction == 0.0 + await bm.record_usage(100) + assert bm.usage_fraction == 0.5 + await bm.record_usage(100) + assert bm.usage_fraction == 1.0 + + +def test_total_budget_reflects_constructor() -> None: + bm = BudgetManager(total_budget=12345) + assert bm.total_budget == 12345 + + +# ── BudgetManager: custom counter ────────────────────────────────────────────── + + +def test_custom_token_counter_is_used() -> None: + calls: list[Any] = [] + + def fake_counter(value: Any) -> int: + calls.append(value) + return 42 + + bm = BudgetManager(total_budget=1000, token_counter=fake_counter) + assert bm.count_tokens({"x": 1}) == 42 + assert calls == [{"x": 1}] + + +# ── BudgetManager: suggested_mode escalation table ───────────────────────────── + + +@pytest.mark.asyncio +async def test_suggested_mode_above_fifty_percent_keeps_requested() -> None: + bm = BudgetManager(total_budget=1000) + # 0% used → 100% remaining + assert bm.suggested_mode("raw") == "raw" + assert bm.suggested_mode("table") == "table" + assert bm.suggested_mode("summary") == "summary" + + +@pytest.mark.asyncio +async def test_suggested_mode_between_twenty_and_fifty_downgrades_raw_only() -> None: + bm = BudgetManager(total_budget=1000) + # Consume 600 → remaining 400 → 40% + await bm.record_usage(600) + assert bm.suggested_mode("raw") == "table" + assert bm.suggested_mode("table") == "table" + assert bm.suggested_mode("summary") == "summary" + assert bm.suggested_mode("handle_only") == "handle_only" + + +@pytest.mark.asyncio +async def test_suggested_mode_between_five_and_twenty_floors_at_summary() -> None: + bm = BudgetManager(total_budget=1000) + # Consume 850 → remaining 150 → 15% + await bm.record_usage(850) + assert bm.suggested_mode("raw") == "summary" + assert bm.suggested_mode("table") == "summary" + assert bm.suggested_mode("summary") == "summary" + assert bm.suggested_mode("handle_only") == "handle_only" + + +@pytest.mark.asyncio +async def test_suggested_mode_under_five_percent_forces_handle_only() -> None: + bm = BudgetManager(total_budget=1000) + # Consume 980 → remaining 20 → 2% + await bm.record_usage(980) + assert bm.suggested_mode("raw") == "handle_only" + assert bm.suggested_mode("table") == "handle_only" + assert bm.suggested_mode("summary") == "handle_only" + assert bm.suggested_mode("handle_only") == "handle_only" + + +@pytest.mark.asyncio +async def test_suggested_mode_boundary_exactly_fifty_percent_downgrades() -> None: + # The boundary is strict-less-than, so remaining == 50% sits in the + # 20–50% bucket and downgrades raw. + bm = BudgetManager(total_budget=1000) + await bm.record_usage(500) + assert bm.suggested_mode("raw") == "table" + + +@pytest.mark.asyncio +async def test_suggested_mode_boundary_exactly_twenty_percent_floors_summary() -> None: + bm = BudgetManager(total_budget=1000) + await bm.record_usage(800) + assert bm.suggested_mode("raw") == "summary" + assert bm.suggested_mode("table") == "summary" diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 4b974b9..10a69fd 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -440,3 +440,295 @@ def evaluate( req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") with pytest.raises(AgentKernelError, match="does not implement explain"): k.explain_denial(req, reader_principal) + + +# ── Cross-invocation budget manager (#44) ───────────────────────────────────── + + +def _kernel_with_budget( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + *, + total_budget: int, + default_request: int = 4_000, +) -> Kernel: + """Helper: construct a kernel wired with a BudgetManager.""" + from agent_kernel import BudgetManager + + router = StaticRouter( + routes={ + "billing.list_invoices": ["memory"], + "billing.get_invoice": ["memory"], + "billing.summarize_spend": ["memory"], + "billing.update_invoice": ["memory"], + "billing.delete_invoice": ["memory"], + } + ) + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + budget_manager=BudgetManager( + total_budget=total_budget, + default_request=default_request, + ), + ) + k.register_driver(memory_driver) + return k + + +@pytest.mark.asyncio +async def test_budget_manager_records_usage_across_invocations( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + reader_principal: Principal, +) -> None: + """Each invocation must move ``remaining`` strictly downward.""" + k = _kernel_with_budget(registry, memory_driver, total_budget=10_000) + assert k.budget is not None + + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + + initial = k.budget.remaining + await k.invoke(token, principal=reader_principal, args={"operation": "billing.list_invoices"}) + after_first = k.budget.remaining + assert after_first < initial + assert k.budget.used > 0 + + # A second invocation consumes more. + await k.invoke(token, principal=reader_principal, args={"operation": "billing.list_invoices"}) + after_second = k.budget.remaining + assert after_second < after_first + + +@pytest.mark.asyncio +async def test_budget_manager_escalates_mode_when_remaining_under_five_percent( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + reader_principal: Principal, +) -> None: + """When remaining drops below 5%, even ``summary`` escalates to ``handle_only``.""" + from agent_kernel import BudgetManager + + router = StaticRouter(routes={"billing.list_invoices": ["memory"]}) + bm = BudgetManager(total_budget=1000) + # Pre-consume to push remaining under 5% before the invoke. + await bm.record_usage(980) + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + budget_manager=bm, + ) + k.register_driver(memory_driver) + + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + frame = await k.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + response_mode="summary", + ) + assert frame.response_mode == "handle_only" + + +@pytest.mark.asyncio +async def test_budget_manager_exhausted_raises_before_driver_runs( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + reader_principal: Principal, +) -> None: + """An exhausted budget surfaces ``BudgetExhausted`` and skips the driver.""" + from agent_kernel import BudgetExhausted, BudgetManager + + router = StaticRouter(routes={"billing.list_invoices": ["memory"]}) + bm = BudgetManager(total_budget=100) + await bm.record_usage(100) # Drive remaining to 0. + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + budget_manager=bm, + ) + k.register_driver(memory_driver) + + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + with pytest.raises(BudgetExhausted, match="Session budget exhausted"): + await k.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + ) + + +@pytest.mark.asyncio +async def test_budget_manager_releases_reservation_on_driver_failure( + registry: CapabilityRegistry, + reader_principal: Principal, +) -> None: + """When all drivers fail, the reserved tokens must return to the pool.""" + from agent_kernel import BudgetManager + + # Construct a kernel with a router pointing at a driver that does not exist. + router = StaticRouter(routes={"billing.list_invoices": ["nope"]}) + bm = BudgetManager(total_budget=1000, default_request=200) + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + budget_manager=bm, + ) + + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + with pytest.raises(DriverError): + await k.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + ) + # Budget was reserved then released — nothing committed. + assert bm.remaining == 1000 + assert bm.used == 0 + + +@pytest.mark.asyncio +async def test_dry_run_reports_budget_remaining_when_manager_configured( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + reader_principal: Principal, +) -> None: + """DryRunResult.budget_remaining is populated when a BudgetManager is wired.""" + from agent_kernel.models import DryRunResult + + k = _kernel_with_budget(registry, memory_driver, total_budget=10_000) + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + result = await k.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + dry_run=True, + ) + assert isinstance(result, DryRunResult) + assert result.budget_remaining == 10_000 # Dry-run does not commit. + + +@pytest.mark.asyncio +async def test_dry_run_reflects_escalated_mode_under_budget_pressure( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + reader_principal: Principal, +) -> None: + """Dry-run mirrors the BudgetManager escalation that a real invoke would apply.""" + from agent_kernel import BudgetManager + from agent_kernel.models import DryRunResult + + router = StaticRouter(routes={"billing.list_invoices": ["memory"]}) + bm = BudgetManager(total_budget=1000) + await bm.record_usage(980) # < 5% remaining. + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + budget_manager=bm, + ) + k.register_driver(memory_driver) + + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + result = await k.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + response_mode="table", + dry_run=True, + ) + assert isinstance(result, DryRunResult) + assert result.response_mode == "handle_only" + + +@pytest.mark.asyncio +async def test_budget_manager_releases_reservation_on_firewall_failure( + registry: CapabilityRegistry, + memory_driver: InMemoryDriver, + reader_principal: Principal, +) -> None: + """Firewall raising after a reservation must release tokens to the pool. + + Without the finally block the reservation would stay locked, permanently + eroding the cumulative budget on every transform failure. + """ + from agent_kernel import BudgetManager, FirewallError + + class FailingFirewall: + def transform(self, *args: object, **kwargs: object) -> object: + raise FirewallError("simulated firewall failure") + + router = StaticRouter(routes={"billing.list_invoices": ["memory"]}) + bm = BudgetManager(total_budget=1000, default_request=200) + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + firewall=FailingFirewall(), # type: ignore[arg-type] + budget_manager=bm, + ) + k.register_driver(memory_driver) + + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = k.get_token(req, reader_principal, justification="") + with pytest.raises(FirewallError, match="simulated firewall failure"): + await k.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + ) + # Reservation was released; nothing committed because the frame never landed. + assert bm.remaining == 1000 + assert bm.used == 0 + + +@pytest.mark.asyncio +async def test_non_admin_raw_request_receives_handle_and_downgraded_frame( + kernel: Kernel, + reader_principal: Principal, +) -> None: + """A non-admin asking for ``raw`` must get a handle + a non-raw frame. + + Before the admin-gate fix the kernel left ``effective_mode == "raw"``, + skipped handle creation, and the firewall then downgraded to summary — + yielding a summary frame *without* a handle. The fix mirrors the + Firewall's admin gate inside the kernel so the handle is always stored. + """ + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = kernel.get_token(req, reader_principal, justification="") + frame = await kernel.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + response_mode="raw", + ) + assert frame.response_mode != "raw" # downgraded + assert frame.handle is not None # handle was stored despite the raw request + + +@pytest.mark.asyncio +async def test_kernel_without_budget_manager_behaves_identically( + kernel: Kernel, + reader_principal: Principal, +) -> None: + """Backward-compat: the default kernel has ``kernel.budget is None``.""" + assert kernel.budget is None + req = CapabilityRequest(capability_id="billing.list_invoices", goal="t") + token = kernel.get_token(req, reader_principal, justification="") + frame = await kernel.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + ) + # No escalation happens — requested mode flows through. + assert frame.response_mode == "summary"