From 259694698f9cada5a8b5c1e05357d3593c7161c4 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:41:11 +0800 Subject: [PATCH 1/3] fix: resolve resource leaks, goroutine leaks, and signal handling issues - Fix MySQL registeredDials sync.Map entry leak by adding proper cleanup with DeregisterDialContext in close hook - Fix goroutine leak in proxy handleConnection by closing connections and waiting for io.Copy goroutines on context cancellation - Add signal handler for stdio MCP mode to enable graceful shutdown of active database connections and SSH tunnels - Add nil config check in NewToolHandler to prevent panic - Fix silent error swallowing in proxy io.Copy operations with logging - Add HTTP server timeouts and proper shutdown error handling in MCP - Fix potential race condition in closeHooks slice with mutex protection All changes pass go vet, golangci-lint, and test suite with race detector. Co-authored-by: Qwen-Coder --- cmd/xsql/mcp.go | 31 ++++++++++++++++++++++++++----- internal/app/conn.go | 11 ++++++++++- internal/db/mysql/driver.go | 30 ++++++++++++++---------------- internal/db/mysql/driver_test.go | 1 - internal/mcp/tools.go | 6 ++++++ internal/proxy/proxy.go | 29 ++++++++++++++++++++++++++--- 6 files changed, 82 insertions(+), 26 deletions(-) diff --git a/cmd/xsql/mcp.go b/cmd/xsql/mcp.go index d47bad1..4797a1d 100644 --- a/cmd/xsql/mcp.go +++ b/cmd/xsql/mcp.go @@ -2,6 +2,7 @@ package main import ( "context" + "log" "net/http" "os" "os/signal" @@ -75,7 +76,18 @@ func runMCPServer(opts *mcpServerOptions) error { switch resolved.transport { case mcp_pkg.TransportStdio: - ctx := context.Background() + // Install signal handler for graceful shutdown in stdio mode + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + signal.Stop(sigChan) + cancel() + }() + return server.Run(ctx, &mcp.StdioTransport{}) case mcp_pkg.TransportStreamableHTTP: handler, err := mcp_pkg.NewStreamableHTTPHandler(server, resolved.httpAuthToken) @@ -86,8 +98,11 @@ func runMCPServer(opts *mcpServerOptions) error { return errors.Wrap(errors.CodeInternal, "failed to create streamable http handler", nil, err) } httpServer := &http.Server{ - Addr: resolved.httpAddr, - Handler: handler, + Addr: resolved.httpAddr, + Handler: handler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, } sigChan := make(chan os.Signal, 1) @@ -95,12 +110,18 @@ func runMCPServer(opts *mcpServerOptions) error { go func() { <-sigChan + signal.Stop(sigChan) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - _ = httpServer.Shutdown(ctx) + if shutdownErr := httpServer.Shutdown(ctx); shutdownErr != nil { + log.Printf("[mcp] http server shutdown error: %v", shutdownErr) + } }() - return httpServer.ListenAndServe() + if listenErr := httpServer.ListenAndServe(); listenErr != nil && listenErr != http.ErrServerClosed { + return listenErr + } + return nil default: return errors.New(errors.CodeCfgInvalid, "unsupported mcp transport", map[string]any{"transport": resolved.transport}) } diff --git a/internal/app/conn.go b/internal/app/conn.go index 48175ad..9e1496c 100644 --- a/internal/app/conn.go +++ b/internal/app/conn.go @@ -3,6 +3,7 @@ package app import ( "context" "database/sql" + "sync" "github.com/zx06/xsql/internal/config" "github.com/zx06/xsql/internal/db" @@ -15,6 +16,7 @@ type Connection struct { DB *sql.DB SSHClient *ssh.Client Profile config.Profile + closeMu sync.Mutex closeHooks []func() } @@ -30,7 +32,11 @@ func (c *Connection) Close() error { errs = append(errs, err) } } - for _, fn := range c.closeHooks { + c.closeMu.Lock() + hooks := c.closeHooks + c.closeHooks = nil + c.closeMu.Unlock() + for _, fn := range hooks { if fn != nil { fn() } @@ -81,6 +87,7 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection } closeHooks := make([]func(), 0, 1) + var hooksMu sync.Mutex connOpts := db.ConnOptions{ DSN: opts.Profile.DSN, Host: opts.Profile.Host, @@ -90,7 +97,9 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection Database: opts.Profile.Database, RegisterCloseHook: func(fn func()) { if fn != nil { + hooksMu.Lock() closeHooks = append(closeHooks, fn) + hooksMu.Unlock() } }, } diff --git a/internal/db/mysql/driver.go b/internal/db/mysql/driver.go index 43ce78a..0687fb3 100644 --- a/internal/db/mysql/driver.go +++ b/internal/db/mysql/driver.go @@ -17,9 +17,8 @@ import ( ) var ( - dialerCounter uint64 - dialers sync.Map - registeredDials sync.Map + dialerCounter uint64 + dialers sync.Map ) func init() { @@ -30,19 +29,17 @@ func registerDialContext(dialer func(context.Context, string, string) (net.Conn, dialerNum := atomic.AddUint64(&dialerCounter, 1) dialName := fmt.Sprintf("xsql_ssh_tunnel_%d", dialerNum) - if _, loaded := registeredDials.LoadOrStore(dialName, true); !loaded { - mysql.RegisterDialContext(dialName, func(ctx context.Context, addr string) (net.Conn, error) { - d, ok := dialers.Load(dialName) - if !ok { - return nil, fmt.Errorf("dialer not found: %s", dialName) - } - fn, ok := d.(func(context.Context, string, string) (net.Conn, error)) - if !ok || fn == nil { - return nil, fmt.Errorf("invalid dialer for: %s", dialName) - } - return fn(ctx, "tcp", addr) - }) - } + mysql.RegisterDialContext(dialName, func(ctx context.Context, addr string) (net.Conn, error) { + d, ok := dialers.Load(dialName) + if !ok { + return nil, fmt.Errorf("dialer not found: %s", dialName) + } + fn, ok := d.(func(context.Context, string, string) (net.Conn, error)) + if !ok || fn == nil { + return nil, fmt.Errorf("invalid dialer for: %s", dialName) + } + return fn(ctx, "tcp", addr) + }) dialers.Store(dialName, dialer) return dialName @@ -81,6 +78,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error if opts.RegisterCloseHook != nil { opts.RegisterCloseHook(func() { dialers.Delete(dialName) + mysql.DeregisterDialContext(dialName) }) } } diff --git a/internal/db/mysql/driver_test.go b/internal/db/mysql/driver_test.go index 56a672f..3d47386 100644 --- a/internal/db/mysql/driver_test.go +++ b/internal/db/mysql/driver_test.go @@ -197,7 +197,6 @@ func TestDriver_Open_ContextCancelled(t *testing.T) { func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) { resetSyncMap(&dialers) - resetSyncMap(®isteredDials) drv, _ := db.Get("mysql") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index 8c11c5c..2833ec0 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -34,6 +34,12 @@ type ToolHandler struct { // NewToolHandler creates a new tool handler func NewToolHandler(cfg *config.File) *ToolHandler { + if cfg == nil { + cfg = &config.File{ + Profiles: map[string]config.Profile{}, + SSHProxies: map[string]config.SSHProxy{}, + } + } return &ToolHandler{ config: cfg, } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index bdd4974..24cf64e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -151,16 +151,22 @@ func (p *Proxy) handleConnection(localConn net.Conn, remoteAddr string) { // Bidirectional copy var wg sync.WaitGroup - wg.Add(2) + errChan := make(chan error, 2) + wg.Add(1) go func() { defer wg.Done() - _, _ = io.Copy(localConn, remoteConn) + if _, err := io.Copy(localConn, remoteConn); err != nil { + errChan <- fmt.Errorf("copy remote->local failed: %w", err) + } }() + wg.Add(1) go func() { defer wg.Done() - _, _ = io.Copy(remoteConn, localConn) + if _, err := io.Copy(remoteConn, localConn); err != nil { + errChan <- fmt.Errorf("copy local->remote failed: %w", err) + } }() // Wait for both copies to finish or context cancellation @@ -172,7 +178,24 @@ func (p *Proxy) handleConnection(localConn net.Conn, remoteAddr string) { select { case <-done: + // Check if there were any copy errors + select { + case err := <-errChan: + log.Printf("[proxy] connection copy error: %v", err) + default: + } case <-p.ctx.Done(): + // Context cancelled: close connections to unblock io.Copy goroutines + _ = localConn.Close() + _ = remoteConn.Close() + // Wait for goroutines to finish + <-done + // Check for any final errors + select { + case err := <-errChan: + log.Printf("[proxy] connection copy error on shutdown: %v", err) + default: + } } } From 609c13aacb308e0bd4ea4ceadcc0edf4ebfa5e7d Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Thu, 9 Apr 2026 17:10:20 +0800 Subject: [PATCH 2/3] fix: handle mcp stdio shutdown and mysql dialer cleanup --- cmd/xsql/command_unit_test.go | 25 +++++++++++++++++++++++++ cmd/xsql/mcp.go | 10 +++++++++- internal/db/mysql/driver.go | 23 +++++++++++++++-------- internal/db/mysql/driver_test.go | 16 ++++++++++++++++ 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index f904cba..0a0550b 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -8,6 +8,8 @@ import ( "path/filepath" "testing" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/config" "github.com/zx06/xsql/internal/errors" @@ -381,6 +383,29 @@ func TestRunMCPServer_ConfigMissing(t *testing.T) { } } +func TestRunMCPServer_StdioTreatsContextCanceledAsCleanExit(t *testing.T) { + prevRun := runMCPStdioServer + runMCPStdioServer = func(ctx context.Context, _ *mcp.Server) error { + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + return cancelCtx.Err() + } + defer func() { + runMCPStdioServer = prevRun + }() + + configPath := filepath.Join(t.TempDir(), "xsql.yaml") + if err := os.WriteFile(configPath, []byte("profiles: {}\nssh_proxies: {}\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + GlobalConfig.ConfigStr = configPath + err := runMCPServer(&mcpServerOptions{}) + if err != nil { + t.Fatalf("expected nil error for canceled stdio server, got %v", err) + } +} + func TestResolveMCPServerOptions_Defaults(t *testing.T) { cfg := config.File{ Profiles: map[string]config.Profile{}, diff --git a/cmd/xsql/mcp.go b/cmd/xsql/mcp.go index 4797a1d..932fa97 100644 --- a/cmd/xsql/mcp.go +++ b/cmd/xsql/mcp.go @@ -2,6 +2,7 @@ package main import ( "context" + stderrors "errors" "log" "net/http" "os" @@ -18,6 +19,10 @@ import ( "github.com/zx06/xsql/internal/secret" ) +var runMCPStdioServer = func(ctx context.Context, server *mcp.Server) error { + return server.Run(ctx, &mcp.StdioTransport{}) +} + // NewMCPCommand creates the MCP command group func NewMCPCommand() *cobra.Command { mcpCmd := &cobra.Command{ @@ -88,7 +93,10 @@ func runMCPServer(opts *mcpServerOptions) error { cancel() }() - return server.Run(ctx, &mcp.StdioTransport{}) + if err := runMCPStdioServer(ctx, server); err != nil && !stderrors.Is(err, context.Canceled) { + return err + } + return nil case mcp_pkg.TransportStreamableHTTP: handler, err := mcp_pkg.NewStreamableHTTPHandler(server, resolved.httpAuthToken) if err != nil { diff --git a/internal/db/mysql/driver.go b/internal/db/mysql/driver.go index 0687fb3..52c8d28 100644 --- a/internal/db/mysql/driver.go +++ b/internal/db/mysql/driver.go @@ -17,8 +17,10 @@ import ( ) var ( - dialerCounter uint64 - dialers sync.Map + dialerCounter uint64 + dialers sync.Map + registerDialContextFn = mysql.RegisterDialContext + deregisterDialContextFn = mysql.DeregisterDialContext ) func init() { @@ -29,7 +31,7 @@ func registerDialContext(dialer func(context.Context, string, string) (net.Conn, dialerNum := atomic.AddUint64(&dialerCounter, 1) dialName := fmt.Sprintf("xsql_ssh_tunnel_%d", dialerNum) - mysql.RegisterDialContext(dialName, func(ctx context.Context, addr string) (net.Conn, error) { + registerDialContextFn(dialName, func(ctx context.Context, addr string) (net.Conn, error) { d, ok := dialers.Load(dialName) if !ok { return nil, fmt.Errorf("dialer not found: %s", dialName) @@ -45,6 +47,14 @@ func registerDialContext(dialer func(context.Context, string, string) (net.Conn, return dialName } +func cleanupDialContext(dialName string) { + if dialName == "" { + return + } + dialers.Delete(dialName) + deregisterDialContextFn(dialName) +} + type Driver struct{} func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *errors.XError) { @@ -77,8 +87,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error cfg.Net = dialName if opts.RegisterCloseHook != nil { opts.RegisterCloseHook(func() { - dialers.Delete(dialName) - mysql.DeregisterDialContext(dialName) + cleanupDialContext(dialName) }) } } @@ -92,9 +101,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error if closeErr := conn.Close(); closeErr != nil { log.Printf("failed to close mysql connection: %v", closeErr) } - if dialName != "" { - dialers.Delete(dialName) - } + cleanupDialContext(dialName) return nil, errors.Wrap(errors.CodeDBConnectFailed, "failed to ping mysql", nil, err) } return conn, nil diff --git a/internal/db/mysql/driver_test.go b/internal/db/mysql/driver_test.go index 3d47386..1ab0641 100644 --- a/internal/db/mysql/driver_test.go +++ b/internal/db/mysql/driver_test.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "testing" "time" @@ -197,6 +198,15 @@ func TestDriver_Open_ContextCancelled(t *testing.T) { func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) { resetSyncMap(&dialers) + var deregisterCalls int32 + prevDeregister := deregisterDialContextFn + deregisterDialContextFn = func(net string) { + atomic.AddInt32(&deregisterCalls, 1) + prevDeregister(net) + } + defer func() { + deregisterDialContextFn = prevDeregister + }() drv, _ := db.Get("mysql") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -226,11 +236,17 @@ func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) { if countSyncMap(&dialers) != 0 { t.Fatal("expected dialers map to be cleaned on open failure") } + if got := atomic.LoadInt32(&deregisterCalls); got != 1 { + t.Fatalf("expected one deregister call on open failure, got %d", got) + } hooks[0]() if countSyncMap(&dialers) != 0 { t.Fatal("expected hook cleanup to be idempotent") } + if got := atomic.LoadInt32(&deregisterCalls); got != 2 { + t.Fatalf("expected close hook cleanup to remain safe, got %d deregister calls", got) + } } func countSyncMap(m *sync.Map) int { From 70af99accba26009b00000a710006f809c139ae7 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Thu, 9 Apr 2026 17:27:54 +0800 Subject: [PATCH 3/3] test: cover mcp stdio shutdown behavior --- cmd/xsql/command_unit_test.go | 22 ++++++++++++++++ tests/e2e/mcp_test.go | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index 0a0550b..d8b8e3c 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -406,6 +406,28 @@ func TestRunMCPServer_StdioTreatsContextCanceledAsCleanExit(t *testing.T) { } } +func TestRunMCPServer_StdioPropagatesNonCanceledError(t *testing.T) { + prevRun := runMCPStdioServer + wantErr := context.DeadlineExceeded + runMCPStdioServer = func(ctx context.Context, _ *mcp.Server) error { + return wantErr + } + defer func() { + runMCPStdioServer = prevRun + }() + + configPath := filepath.Join(t.TempDir(), "xsql.yaml") + if err := os.WriteFile(configPath, []byte("profiles: {}\nssh_proxies: {}\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + GlobalConfig.ConfigStr = configPath + err := runMCPServer(&mcpServerOptions{}) + if err != wantErr { + t.Fatalf("expected %v, got %v", wantErr, err) + } +} + func TestResolveMCPServerOptions_Defaults(t *testing.T) { cfg := config.File{ Profiles: map[string]config.Profile{}, diff --git a/tests/e2e/mcp_test.go b/tests/e2e/mcp_test.go index dc156fa..89da464 100644 --- a/tests/e2e/mcp_test.go +++ b/tests/e2e/mcp_test.go @@ -8,7 +8,11 @@ import ( "os" "os/exec" "path/filepath" + "runtime" + "strings" + "syscall" "testing" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -234,6 +238,51 @@ func TestMCPServer_EmptyConfig(t *testing.T) { } } +func TestMCPServer_SIGINTCleanExit(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } + + config := createTempConfig(t, `profiles: {}`) + + cmd := exec.Command(testBinary, "mcp", "server", "--config", config) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start MCP server: %v", err) + } + + time.Sleep(200 * time.Millisecond) + + if err := cmd.Process.Signal(syscall.SIGINT); err != nil { + t.Fatalf("failed to send SIGINT: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + t.Fatalf("expected clean exit, got exit code %d, stderr: %s", exitErr.ExitCode(), stderr.String()) + } + t.Fatalf("expected clean exit, got %v", err) + } + case <-time.After(5 * time.Second): + _ = cmd.Process.Kill() + t.Fatal("mcp server did not exit after SIGINT") + } + + if strings.Contains(stderr.String(), "context canceled") { + t.Fatalf("stderr should not contain context canceled, got: %s", stderr.String()) + } +} + // listMCPTools lists all available MCP tools func listMCPTools(t *testing.T, configPath string) []mcp.Tool { t.Helper()