diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpRequestHandler.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpRequestHandler.java new file mode 100644 index 000000000..79890dd65 --- /dev/null +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpRequestHandler.java @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import java.util.function.Consumer; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.java.mcp.model.JsonRpcResponse; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Handler for MCP JSON-RPC requests. Used as the {@code next} parameter in + * {@link McpRequestInterceptor} chains, with {@link McpService} as the terminal handler. + * + *

This interface mirrors the contract of {@link McpService#handleRequest}. Responses + * are delivered through one of two channels depending on the request type: + * + *

+ * + *

Interceptors must handle all three cases. See {@link McpRequestInterceptor} for patterns. + */ +@SmithyUnstableApi +@FunctionalInterface +public interface McpRequestHandler { + + /** + * Handles a JSON-RPC request. + * + * @param request The JSON-RPC request to handle + * @param asyncResponseCallback Callback for async responses (proxy tool calls). + * Interceptors that need to observe async responses should wrap this callback. + * @param protocolVersion The protocol version for this request + * @return The JSON-RPC response for synchronous operations, or {@code null} for + * async operations, notifications, and unknown methods + */ + JsonRpcResponse handleRequest( + JsonRpcRequest request, + Consumer asyncResponseCallback, + ProtocolVersion protocolVersion + ); +} diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpRequestInterceptor.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpRequestInterceptor.java new file mode 100644 index 000000000..173e03390 --- /dev/null +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpRequestInterceptor.java @@ -0,0 +1,95 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import java.util.function.Consumer; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.java.mcp.model.JsonRpcResponse; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Interceptor for MCP JSON-RPC request handling. Interceptors form a chain with + * {@link McpService} as the terminal handler. Each interceptor can observe, modify, + * or short-circuit request handling. + * + *

Response contract

+ * + *

Responses are delivered through one of two channels. Interceptors must handle both: + * + *

+ * + *

To observe or modify async responses, wrap the callback before passing it to + * {@code next}. To handle both paths, instrument the return value and the callback: + * + *

Example: telemetry interceptor

+ *
{@code
+ * McpService.builder()
+ *     .addInterceptor((request, callback, version, next) -> {
+ *         long start = System.nanoTime();
+ *         // Wrap callback to observe async responses (proxy tool calls)
+ *         Consumer wrapped = response -> {
+ *             emitMetrics(request, response, System.nanoTime() - start);
+ *             callback.accept(response);
+ *         };
+ *         var response = next.handleRequest(request, wrapped, version);
+ *         // Handle sync responses (ping, tools/list, local tool calls, etc.)
+ *         if (response != null) {
+ *             emitMetrics(request, response, System.nanoTime() - start);
+ *         }
+ *         return response;
+ *     })
+ *     .services(services)
+ *     .build();
+ * }
+ * + *

Example: short-circuit interceptor

+ *
{@code
+ * .addInterceptor((request, callback, version, next) -> {
+ *     if (isBlocked(request)) {
+ *         return JsonRpcResponse.builder()
+ *             .id(request.getId())
+ *             .error(JsonRpcErrorResponse.builder()
+ *                 .code(403)
+ *                 .message("Blocked")
+ *                 .build())
+ *             .jsonrpc("2.0")
+ *             .build();
+ *     }
+ *     return next.handleRequest(request, callback, version);
+ * })
+ * }
+ * + *

Interceptors are invoked in the order they are added. The last interceptor's + * {@code next} parameter delegates to {@link McpService}. + */ +@SmithyUnstableApi +@FunctionalInterface +public interface McpRequestInterceptor { + + /** + * Intercepts an MCP JSON-RPC request. + * + * @param request The JSON-RPC request + * @param asyncResponseCallback Callback for async responses. Wrap this to observe + * or modify async proxy responses before they reach the transport. + * @param protocolVersion The protocol version for this request + * @param next The next handler in the chain (ultimately McpService) + * @return The JSON-RPC response for synchronous operations, or {@code null} for + * async operations, notifications, and unknown methods + */ + JsonRpcResponse intercept( + JsonRpcRequest request, + Consumer asyncResponseCallback, + ProtocolVersion protocolVersion, + McpRequestHandler next + ); +} diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java index fa40be0db..1fa81ea5f 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java @@ -24,6 +24,7 @@ public final class McpServerBuilder { OutputStream os; Map services = new HashMap<>(); List proxyList = new ArrayList<>(); + List interceptors = new ArrayList<>(); String name; String version; ToolFilter toolFilter = (server, tool) -> true; @@ -61,7 +62,7 @@ public McpServerBuilder version(String version) { public Server build() { validate(); // Create McpService before building McpServer - var builder = McpService.builder() + var serviceBuilder = McpService.builder() .services(services) .proxyList(proxyList) .name(name != null ? name : "mcp-server") @@ -69,10 +70,14 @@ public Server build() { .metricsObserver(metricsObserver); if (version != null) { - builder.version(version); + serviceBuilder.version(version); } - this.mcpService = builder.build(); + for (var interceptor : interceptors) { + serviceBuilder.addInterceptor(interceptor); + } + + this.mcpService = serviceBuilder.build(); return new McpServer(this); } @@ -101,6 +106,17 @@ public McpServerBuilder metricsObserver(McpMetricsObserver observer) { return this; } + /** + * Adds a request interceptor to the chain. Interceptors are invoked in the order + * they are added, with {@link McpService} as the terminal handler. + * + * @see McpRequestInterceptor for usage patterns and the response contract + */ + public McpServerBuilder addInterceptor(McpRequestInterceptor interceptor) { + this.interceptors.add(Objects.requireNonNull(interceptor, "interceptor")); + return this; + } + private void validate() { Objects.requireNonNull(is, "MCP server input stream is required"); Objects.requireNonNull(os, "MCP server output stream is required"); diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java index 12722c1e6..22e61867f 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java @@ -101,11 +101,11 @@ protected final ProtocolVersion getProtocolVersion() { return protocolVersion.get(); } - abstract CompletableFuture rpc(JsonRpcRequest request); + protected abstract CompletableFuture rpc(JsonRpcRequest request); - abstract void start(); + protected abstract void start(); - abstract CompletableFuture shutdown(); + protected abstract CompletableFuture shutdown(); protected CompletableFuture rpc(String method, ShapeBuilder builder) { JsonRpcRequest request = JsonRpcRequest.builder() diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 662020f8a..b2ec133ee 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; @@ -95,6 +96,7 @@ public final class McpService { private final AtomicReference proxiesInitialized = new AtomicReference<>(false); private final McpMetricsObserver metricsObserver; private final SchemaIndex schemaIndex; + private final McpRequestHandler requestHandler; private Consumer notificationWriter; McpService( @@ -103,7 +105,8 @@ public final class McpService { String name, String version, ToolFilter toolFilter, - McpMetricsObserver metricsObserver + McpMetricsObserver metricsObserver, + List interceptors ) { this.services = services; this.schemaIndex = @@ -115,22 +118,53 @@ public final class McpService { this.proxies = proxyList.stream().collect(Collectors.toMap(McpServerProxy::name, p -> p)); this.toolFilter = toolFilter; this.metricsObserver = metricsObserver; + this.requestHandler = buildInterceptorChain(interceptors); + } + + private McpRequestHandler buildInterceptorChain(List interceptors) { + McpRequestHandler handler = this::handleRequestInternal; + // Wrap in reverse order so first-added interceptor is outermost + for (int i = interceptors.size() - 1; i >= 0; i--) { + var interceptor = interceptors.get(i); + var next = handler; + handler = (req, callback, version) -> interceptor.intercept(req, callback, version, next); + } + return handler; } /** - * Handles a JSON-RPC request synchronously and returns a response. - * For proxy tool calls, the response callback is invoked asynchronously and this method returns null. - * For local operations, the response is returned immediately. + * Handles a JSON-RPC request, dispatching through the interceptor chain if any + * interceptors are registered. Responses are delivered through one of two channels: + * + *

* * @param req The JSON-RPC request to handle * @param asyncResponseCallback Callback for async responses (used for proxy calls) * @param protocolVersion The protocol version for this request (may be null) - * @return The response for synchronous operations, or null for async operations + * @return The response for synchronous operations, or null for async/notification operations */ public JsonRpcResponse handleRequest( JsonRpcRequest req, Consumer asyncResponseCallback, ProtocolVersion protocolVersion + ) { + return requestHandler.handleRequest(req, asyncResponseCallback, protocolVersion); + } + + /** + * Internal request handling logic. This is the terminal handler in the interceptor chain. + */ + private JsonRpcResponse handleRequestInternal( + JsonRpcRequest req, + Consumer asyncResponseCallback, + ProtocolVersion protocolVersion ) { try { validate(req); @@ -1165,6 +1199,7 @@ public static Builder builder() { public static class Builder { private Map services = new HashMap<>(); private List proxyList = new ArrayList<>(); + private List interceptors = new ArrayList<>(); private String name = "mcp-server"; private String version = "1.0.0"; private ToolFilter toolFilter = (serverId, toolName) -> true; @@ -1200,8 +1235,19 @@ public Builder metricsObserver(McpMetricsObserver metricsObserver) { return this; } + /** + * Adds a request interceptor to the chain. Interceptors are invoked in the order + * they are added, with {@link McpService} as the terminal handler. + * + * @see McpRequestInterceptor for usage patterns and the response contract + */ + public Builder addInterceptor(McpRequestInterceptor interceptor) { + this.interceptors.add(Objects.requireNonNull(interceptor, "interceptor")); + return this; + } + public McpService build() { - return new McpService(services, proxyList, name, version, toolFilter, metricsObserver); + return new McpService(services, proxyList, name, version, toolFilter, metricsObserver, interceptors); } } } diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java index aef941dec..242f389bd 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java @@ -1695,7 +1695,7 @@ public List listPrompts() { } @Override - CompletableFuture rpc(JsonRpcRequest request) { + protected CompletableFuture rpc(JsonRpcRequest request) { // Notifications have no ID if (request.getId() == null) { sentNotifications.add(request.getMethod()); @@ -1714,10 +1714,10 @@ List getSentNotifications() { } @Override - void start() {} + protected void start() {} @Override - CompletableFuture shutdown() { + protected CompletableFuture shutdown() { return CompletableFuture.completedFuture(null); } @@ -1730,4 +1730,265 @@ void sendNotification(JsonRpcRequest notification) { notify(notification); } } + + @Test + void testInterceptorSeesRequestAndResponse() { + var interceptedRequests = new ArrayList(); + var interceptedResponses = new ArrayList(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addInterceptor((request, callback, version, next) -> { + interceptedRequests.add(request.getMethod()); + var response = next.handleRequest(request, callback, version); + interceptedResponses.add(response); + return response; + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + read(); + + assertEquals(1, interceptedRequests.size()); + assertEquals("ping", interceptedRequests.get(0)); + assertEquals(1, interceptedResponses.size()); + assertNotNull(interceptedResponses.get(0)); + } + + @Test + void testMultipleInterceptorsChainInOrder() { + var callOrder = new ArrayList(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addInterceptor((request, callback, version, next) -> { + callOrder.add("A-before"); + var response = next.handleRequest(request, callback, version); + callOrder.add("A-after"); + return response; + }) + .addInterceptor((request, callback, version, next) -> { + callOrder.add("B-before"); + var response = next.handleRequest(request, callback, version); + callOrder.add("B-after"); + return response; + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + read(); + + assertEquals(List.of("A-before", "B-before", "B-after", "A-after"), callOrder); + } + + @Test + void testInterceptorCanShortCircuit() { + var shortCircuitResponse = JsonRpcResponse.builder() + .id(Document.of(0)) + .result(Document.of(Map.of("short", Document.of("circuited")))) + .jsonrpc("2.0") + .build(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addInterceptor((request, callback, version, next) -> shortCircuitResponse) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + var response = read(); + + assertEquals("circuited", response.getResult().getMember("short").asString()); + } + + @Test + void testInterceptorWorksWithToolCalls() { + var interceptedMethods = new ArrayList(); + var responseReceived = new AtomicReference(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addInterceptor((request, callback, version, next) -> { + interceptedMethods.add(request.getMethod()); + var response = next.handleRequest(request, callback, version); + responseReceived.set(response); + return response; + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + // tools/list should go through the interceptor + write("tools/list", Document.of(Map.of())); + read(); + + assertTrue(interceptedMethods.contains("tools/list")); + assertNotNull(responseReceived.get()); + } + + @Test + void testNotificationPassesThroughInterceptor() { + var interceptedMethods = new ArrayList(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addInterceptor((request, callback, version, next) -> { + interceptedMethods.add(request.getMethod()); + return next.handleRequest(request, callback, version); + }) + .build(); + + server.start(); + + // Send a notification (no id) — should flow through interceptor, produce no output + writeNotification("notifications/initialized", Document.of(Map.of())); + output.assertNoOutput(); + + // Verify server is still alive by sending a real request after the notification + write("ping", Document.of(Map.of())); + var response = read(); + assertNotNull(response.getResult()); + + // The interceptor should have seen both the notification and the ping + assertTrue(interceptedMethods.contains("notifications/initialized")); + assertTrue(interceptedMethods.contains("ping")); + } + + @Test + void testInterceptorWorksWithProxyToolCalls() { + var interceptedMethods = new ArrayList(); + var mockProxy = new CacheTestProxy(new AtomicInteger(0)); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addService(mockProxy) + .addInterceptor((request, callback, version, next) -> { + interceptedMethods.add(request.getMethod()); + return next.handleRequest(request, callback, version); + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + // tools/list goes through the interceptor and includes proxy tools + write("tools/list", Document.of(Map.of())); + var response = read(); + var tools = response.getResult().asStringMap().get("tools").asList(); + var toolNames = tools.stream() + .map(t -> t.asStringMap().get("name").asString()) + .toList(); + assertTrue(toolNames.contains("test-tool"), "Proxy tool should be listed"); + + // Call the proxy tool through the interceptor chain + write("tools/call", + Document.of(Map.of( + "name", + Document.of("test-tool"), + "arguments", + Document.of(Map.of())))); + response = read(); + assertNotNull(response); + assertNull(response.getError()); + + assertTrue(interceptedMethods.contains("tools/call")); + } + + @Test + void testInterceptorCanWrapAsyncCallback() { + var asyncMethodsCaptured = new ArrayList(); + var mockProxy = new CacheTestProxy(new AtomicInteger(0)); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addService(mockProxy) + .addInterceptor((request, callback, version, next) -> { + // Wrap callback to observe async responses + var wrapped = new java.util.function.Consumer() { + @Override + public void accept(JsonRpcResponse response) { + asyncMethodsCaptured.add(request.getMethod()); + callback.accept(response); + } + }; + return next.handleRequest(request, wrapped, version); + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + // Call the proxy tool — response comes via async callback + write("tools/call", + Document.of(Map.of( + "name", + Document.of("test-tool"), + "arguments", + Document.of(Map.of())))); + var response = read(); + assertNotNull(response); + assertNull(response.getError()); + + // The wrapped callback should have captured the async method + assertTrue(asyncMethodsCaptured.contains("tools/call")); + } }