Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ client = ["httpx[http2]"]
adk = ["google-adk>=1.20.0"]
openai = ["openai-agents>=0.6.1"]
pydantic_ai = ["pydantic-ai-slim>=1.68.0"]
langchain = ["langchain>=1.0.0", "langgraph>=1.0.0", "langchain-core>=1.0.0"]
tracing = ["opentelemetry-api>=1.36.0"]

[build-system]
Expand Down
54 changes: 54 additions & 0 deletions python/restate/ext/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""Restate integration for LangChain agents.

Pass `RestateMiddleware()` to `create_agent(..., middleware=[...])` and run
the agent inside a Restate handler. LLM responses are journaled, so retries
replay them instead of re-calling the model. To make tool side effects
durable, wrap them with `restate_context().run_typed("name", ...)` inside
the tool body.
"""

import typing

from restate import Context, ObjectContext
from restate.server_context import current_context

from ._middleware import RestateMiddleware


def restate_context() -> Context:
    """Return the Restate Context of the currently running invocation.

    Intended for tool bodies: wrap side effects in
    `ctx.run_typed("name", ...)` to make them durable. The middleware does
    NOT wrap tool calls on your behalf.

    Raises:
        RuntimeError: if `current_context()` returns None.
    """
    active = current_context()
    if active is not None:
        return active
    raise RuntimeError("No Restate context found.")


def restate_object_context() -> ObjectContext:
    """Return the current Restate context typed as an `ObjectContext`.

    The context is only cast for the benefit of the type checker; this
    function does not verify that the surrounding handler belongs to a
    Virtual Object, so callers must only use it from such a handler.

    Raises:
        RuntimeError: if there is no current Restate context.
    """
    # Delegate to restate_context() so both accessors share one lookup and
    # raise the same error when no context is available.
    return typing.cast(ObjectContext, restate_context())


# Names exported by `from restate.ext.langchain import *`.
__all__ = [
    "RestateMiddleware",
    "restate_context",
    "restate_object_context",
]
128 changes: 128 additions & 0 deletions python/restate/ext/langchain/_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#
# Copyright (c) 2023-2026 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""LangChain agent middleware that makes a `create_agent` agent durable on Restate.

- `awrap_model_call` journals each LLM response so retries replay it from the
journal instead of re-calling the model.
- `awrap_tool_call` runs parallel tool calls one at a time (via a turnstile
keyed on `tool_call_id`) so any `ctx.run_typed(...)` calls users place in
tool bodies appear in the journal in a stable order across replays.

The middleware does not journal tool calls itself and does not catch
exceptions — wrap side effects explicitly with `restate_context().run_typed(...)`
inside the tool body.
"""

from dataclasses import asdict
from typing import Any, Awaitable, Callable, Optional, cast

from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage, BaseMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from pydantic import BaseModel

from restate import RunOptions
from restate.extensions import current_context
from restate.ext.turnstile import Turnstile

from ._state import current_state

# Return type of a tool-call handler as used by `awrap_tool_call` below.
ToolCallResult = ToolMessage | Command


class SerializableModelResponse(BaseModel):
    """Pydantic counterpart of `ModelResponse`, used as the journaled value.

    `result` is declared as `list[AnyMessage]` (a discriminated union)
    instead of the `BaseMessage` list that `ModelResponse` carries, so that
    the `tool_calls` of an AIMessage survive serialization; with
    `BaseMessage` they would not.
    """

    result: list[AnyMessage]
    structured_response: Any | None = None


class RestateMiddleware(AgentMiddleware):
    """Agent middleware that runs a `create_agent` agent durably on Restate.

    Use it as `create_agent(..., middleware=[RestateMiddleware()])` and invoke
    the agent from inside a Restate handler. LLM responses are journaled and
    the parallel tool calls of one AI message are linearized so that replays
    are deterministic.

    Args:
        run_options: forwarded to the `ctx.run_typed` call that wraps the
            LLM (max attempts, retry intervals, ...). `serde` is set internally.
    """

    def __init__(self, run_options: Optional[RunOptions[Any]] = None):
        super().__init__()
        if not run_options:
            run_options = RunOptions()
        self._options: RunOptions[Any] = run_options

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Run the model call through `ctx.run_typed` and arm the tool turnstile.

        Raises:
            RuntimeError: if there is no current Restate context.
        """
        restate_ctx = current_context()
        if restate_ctx is None:
            raise RuntimeError(
                "RestateMiddleware must run inside a Restate handler. "
                "Call agent.ainvoke(...) from a handler that exposes a Restate Context."
            )

        async def invoke_model() -> SerializableModelResponse:
            # Convert the dataclass response into its pydantic mirror so the
            # value handed to the journal is serializable.
            raw = await handler(request)
            return SerializableModelResponse(**asdict(raw))

        recorded = await restate_ctx.run_typed("LLM call", invoke_model, self._options)

        # When the request asked for a Pydantic schema, turn the journaled
        # structured output back into an instance of that schema.
        structured = recorded.structured_response
        wanted_schema = getattr(request.response_format, "schema", None)
        if structured is not None:
            if isinstance(wanted_schema, type) and issubclass(wanted_schema, BaseModel):
                structured = wanted_schema.model_validate(structured)

        # Install a turnstile over the tool call ids of the first AI message
        # so the tools run one after another instead of in parallel under
        # asyncio.gather().
        first_ai = None
        for message in recorded.result:
            if isinstance(message, AIMessage):
                first_ai = message
                break
        if first_ai:
            ordered_ids = []
            for call in first_ai.tool_calls or []:
                call_id = call.get("id")
                if call_id is not None:
                    ordered_ids.append(call_id)
            current_state().turnstile = Turnstile(ordered_ids)

        # Hand the agent the ModelResponse type it expects.
        return ModelResponse(
            result=cast(list[BaseMessage], recorded.result),
            structured_response=structured,
        )

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolCallResult]],
    ) -> ToolCallResult:
        """Execute a tool call once the turnstile lets its `tool_call_id` through.

        Tool calls without an id bypass the turnstile and run immediately.
        """
        call = request.tool_call
        call_id: Optional[str] = None
        if isinstance(call, dict):
            call_id = call.get("id")
        if call_id is None:
            return await handler(request)

        gate = current_state().turnstile
        try:
            # Wait for this call's turn, run it, then release the next one.
            await gate.wait_for(call_id)
            outcome = await handler(request)
            gate.allow_next_after(call_id)
        except BaseException:
            # Unblock the rest of the parallel tool batch, then propagate.
            gate.cancel_all_after(call_id)
            raise
        return outcome
28 changes: 28 additions & 0 deletions python/restate/ext/langchain/_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

from contextvars import ContextVar

from restate.ext.turnstile import Turnstile


class _State:
    """Mutable holder for the turnstile that orders the current tool calls."""

    __slots__ = ("turnstile",)

    turnstile: Turnstile

    def __init__(self) -> None:
        # Start with a turnstile built from an empty id list; the middleware
        # replaces it after each model call that produces an AI message.
        self.turnstile = Turnstile([])


# Context variable holding the LangChain integration state.
# NOTE(review): `default=_State()` is evaluated once at import time, so every
# context that never calls `_state_var.set(...)` receives that same object.
# No `set` call is visible in this change — confirm that sharing one `_State`
# (and therefore one turnstile) across concurrent invocations is intended.
_state_var: ContextVar[_State] = ContextVar("restate_langchain_state", default=_State())


def current_state() -> _State:
    """Return the `_State` that `_state_var` yields in the current context."""
    state = _state_var.get()
    return state
1 change: 0 additions & 1 deletion python/restate/ext/pydantic/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
from datetime import timedelta
from typing import Any, overload

from restate import RunOptions, TerminalError
Expand Down
Loading
Loading