From 164d194c47b43edeab4a55166bcb06c3ca07d2a3 Mon Sep 17 00:00:00 2001 From: Brian Fox <878612+onematchfox@users.noreply.github.com> Date: Tue, 12 May 2026 15:41:08 +0200 Subject: [PATCH] fix(controller): recover MCP auth session from `RequestExtra` in tool handlers The Go MCP SDK detaches the HTTP request context before dispatching to tool handlers. From the [SDK source](https://github.com/modelcontextprotocol/go-sdk/blob/v1.5.0/mcp/streamable.go#L485-L487): > // Pass req.Context() here, to allow middleware to add context values. > // The context is detached in the jsonrpc2 library when handling the > // long-running stream. This means the auth session placed by `AuthnMiddleware` is not visible via `auth.AuthSessionFrom(ctx)` in tool handlers. The SDK does preserve the original HTTP headers in [RequestExtra.Header](https://github.com/modelcontextprotocol/go-sdk/blob/v1.5.0/mcp/streamable.go#L1155-L1158) though. Re-authenticate from those headers at the top of handleInvokeAgent so the A2A client's outbound request to the agent carries the user's JWT. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Brian Fox <878612+onematchfox@users.noreply.github.com> --- go/core/internal/mcp/mcp_handler.go | 10 ++ go/core/internal/mcp/mcp_handler_test.go | 175 +++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 go/core/internal/mcp/mcp_handler_test.go diff --git a/go/core/internal/mcp/mcp_handler.go b/go/core/internal/mcp/mcp_handler.go index 8182df6fb..6b157a3e7 100644 --- a/go/core/internal/mcp/mcp_handler.go +++ b/go/core/internal/mcp/mcp_handler.go @@ -183,6 +183,16 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) { log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "invoke_agent") + // The Go MCP SDK detaches the HTTP request context when dispatching to + // tool handlers, so auth.AuthSessionFrom(ctx) returns nothing. Recover + // the auth session from the HTTP headers preserved in RequestExtra so + // that the A2A client's outbound request to the agent carries the user's JWT. + if extra := req.GetExtra(); extra != nil { + if session, err := h.authenticator.Authenticate(ctx, extra.Header, nil); err == nil { + ctx = auth.AuthSessionTo(ctx, session) + } + } + // Parse agent reference (namespace/name or just name) agentNS, agentName, ok := strings.Cut(input.Agent, "/") if !ok { diff --git a/go/core/internal/mcp/mcp_handler_test.go b/go/core/internal/mcp/mcp_handler_test.go new file mode 100644 index 000000000..a01561bab --- /dev/null +++ b/go/core/internal/mcp/mcp_handler_test.go @@ -0,0 +1,175 @@ +package mcp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + "github.com/kagent-dev/kagent/go/core/pkg/auth" + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeSession is a minimal auth.Session for testing. +type fakeSession struct{ principal auth.Principal } + +func (s *fakeSession) Principal() auth.Principal { return s.principal } + +// fakeAuthProvider propagates the incoming Bearer token to upstream requests unchanged. +type fakeAuthProvider struct { + session auth.Session +} + +func (f *fakeAuthProvider) Authenticate(_ context.Context, headers http.Header, _ url.Values) (auth.Session, error) { + if headers.Get("Authorization") != "" { + return f.session, nil + } + return nil, http.ErrNoCookie +} + +func (f *fakeAuthProvider) UpstreamAuth(r *http.Request, _ auth.Session, _ auth.Principal) error { + r.Header.Set("Authorization", "Bearer upstream-token") + return nil +} + +// a2aBackend is a fake A2A server that records the Authorization header of each request. +type a2aBackend struct { + server *httptest.Server + mu sync.Mutex + lastAuthHeader string +} + +func (b *a2aBackend) getLastAuthHeader() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.lastAuthHeader +} + +func newA2ABackend(t *testing.T) *a2aBackend { + t.Helper() + b := &a2aBackend{} + b.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b.mu.Lock() + b.lastAuthHeader = r.Header.Get("Authorization") + b.mu.Unlock() + resp := map[string]any{ + "jsonrpc": "2.0", + "id": "", + "result": map[string]any{ + "kind": "message", + "messageId": "test-msg", + "role": "agent", + "parts": []any{map[string]any{"kind": "text", "text": "hello from agent"}}, + }, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Errorf("failed to encode fake A2A response: %v", err) + } + })) + t.Cleanup(b.server.Close) + return b +} + +// authRoundTripper injects a fixed Authorization header into every outgoing request. +type authRoundTripper struct { + base http.RoundTripper + token string +} + +func (a *authRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + r = r.Clone(r.Context()) + r.Header.Set("Authorization", "Bearer "+a.token) + return a.base.RoundTrip(r) +} + +// TestInvokeAgent_AuthPropagation exercises the full MCP HTTP stack: +// the MCP client sends a request with an Authorization header, the handler +// recovers the auth session from RequestExtra, and the A2A backend receives +// the token produced by UpstreamAuth. +func TestInvokeAgent_AuthPropagation(t *testing.T) { + // Fake A2A backend — records the Authorization header it receives. + backend := newA2ABackend(t) + + authProvider := &fakeAuthProvider{session: &fakeSession{}} + + // Real MCP handler (kubeClient is nil; invoke_agent does not use it). + mcpHandler, err := NewMCPHandler(nil, backend.server.URL, authProvider, 5*time.Second) + require.NoError(t, err) + + mcpServer := httptest.NewServer(mcpHandler) + t.Cleanup(mcpServer.Close) + + // MCP client whose HTTP transport injects an Authorization header on every request. + transport := &mcpsdk.StreamableClientTransport{ + Endpoint: mcpServer.URL, + HTTPClient: &http.Client{ + Transport: &authRoundTripper{ + base: http.DefaultTransport, + token: "test-token", + }, + }, + DisableStandaloneSSE: true, + } + + ctx := context.Background() + cs, err := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "1.0"}, nil). + Connect(ctx, transport, nil) + require.NoError(t, err) + t.Cleanup(func() { cs.Close() }) + + result, err := cs.CallTool(ctx, &mcpsdk.CallToolParams{ + Name: "invoke_agent", + Arguments: map[string]any{ + "agent": "default/test-agent", + "task": "say hello", + }, + }) + require.NoError(t, err) + assert.False(t, result.IsError, "expected successful tool result, got: %v", result.Content) + assert.Equal(t, "Bearer upstream-token", backend.getLastAuthHeader(), "A2A backend should receive the token produced by UpstreamAuth") +} + +// TestInvokeAgent_NoAuthPropagationWithoutHeader verifies that when the MCP +// client sends no Authorization header, no Authorization header is +// propagated to the A2A backend. +func TestInvokeAgent_NoAuthPropagationWithoutHeader(t *testing.T) { + backend := newA2ABackend(t) + + authProvider := &fakeAuthProvider{session: &fakeSession{}} + + mcpHandler, err := NewMCPHandler(nil, backend.server.URL, authProvider, 5*time.Second) + require.NoError(t, err) + + mcpServer := httptest.NewServer(mcpHandler) + t.Cleanup(mcpServer.Close) + + // No custom transport — requests carry no Authorization header. + transport := &mcpsdk.StreamableClientTransport{ + Endpoint: mcpServer.URL, + DisableStandaloneSSE: true, + } + + ctx := context.Background() + cs, err := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "1.0"}, nil). + Connect(ctx, transport, nil) + require.NoError(t, err) + t.Cleanup(func() { cs.Close() }) + + result, err := cs.CallTool(ctx, &mcpsdk.CallToolParams{ + Name: "invoke_agent", + Arguments: map[string]any{ + "agent": "default/test-agent", + "task": "say hello", + }, + }) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Empty(t, backend.getLastAuthHeader(), "A2A backend should receive no Authorization header when the client sends none") +}