Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions cmd/xsql/command_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"path/filepath"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/zx06/xsql/internal/app"
"github.com/zx06/xsql/internal/config"
"github.com/zx06/xsql/internal/errors"
Expand Down Expand Up @@ -345,7 +347,7 @@
unsafe_allow_write: true
`
if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
t.Fatalf("failed to write config: %v", err)

Check failure on line 350 in cmd/xsql/command_unit_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "failed to write config: %v" 4 times.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ1xgrVSbfjn6n6i3tFM&open=AZ1xgrVSbfjn6n6i3tFM&pullRequest=41
}

GlobalConfig.ConfigStr = configPath
Expand Down Expand Up @@ -381,6 +383,51 @@
}
}

func TestRunMCPServer_StdioTreatsContextCanceledAsCleanExit(t *testing.T) {
prevRun := runMCPStdioServer
runMCPStdioServer = func(ctx context.Context, _ *mcp.Server) error {
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
return cancelCtx.Err()
}
defer func() {
runMCPStdioServer = prevRun
}()

configPath := filepath.Join(t.TempDir(), "xsql.yaml")
if err := os.WriteFile(configPath, []byte("profiles: {}\nssh_proxies: {}\n"), 0644); err != nil {

Check failure on line 398 in cmd/xsql/command_unit_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "profiles: {}\nssh_proxies: {}\n" 4 times.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ1xgrVSbfjn6n6i3tFN&open=AZ1xgrVSbfjn6n6i3tFN&pullRequest=41
t.Fatalf("failed to write config: %v", err)
}

GlobalConfig.ConfigStr = configPath
err := runMCPServer(&mcpServerOptions{})
if err != nil {
t.Fatalf("expected nil error for canceled stdio server, got %v", err)
}
}

func TestRunMCPServer_StdioPropagatesNonCanceledError(t *testing.T) {
prevRun := runMCPStdioServer
wantErr := context.DeadlineExceeded
runMCPStdioServer = func(ctx context.Context, _ *mcp.Server) error {
return wantErr
}
defer func() {
runMCPStdioServer = prevRun
}()

configPath := filepath.Join(t.TempDir(), "xsql.yaml")
if err := os.WriteFile(configPath, []byte("profiles: {}\nssh_proxies: {}\n"), 0644); err != nil {
t.Fatalf("failed to write config: %v", err)
}

GlobalConfig.ConfigStr = configPath
err := runMCPServer(&mcpServerOptions{})
if err != wantErr {
t.Fatalf("expected %v, got %v", wantErr, err)
}
}

func TestResolveMCPServerOptions_Defaults(t *testing.T) {
cfg := config.File{
Profiles: map[string]config.Profile{},
Expand Down
41 changes: 35 additions & 6 deletions cmd/xsql/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import (
"context"
stderrors "errors"
"log"
"net/http"
"os"
"os/signal"
Expand All @@ -17,6 +19,10 @@
"github.com/zx06/xsql/internal/secret"
)

var runMCPStdioServer = func(ctx context.Context, server *mcp.Server) error {
return server.Run(ctx, &mcp.StdioTransport{})
}

// NewMCPCommand creates the MCP command group
func NewMCPCommand() *cobra.Command {
mcpCmd := &cobra.Command{
Expand Down Expand Up @@ -49,7 +55,7 @@
}

// runMCPServer runs the MCP server
func runMCPServer(opts *mcpServerOptions) error {

Check failure on line 58 in cmd/xsql/mcp.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this method to reduce its Cognitive Complexity from 20 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ1xco8jhKIb2OI0PIhc&open=AZ1xco8jhKIb2OI0PIhc&pullRequest=41
// Load config
cfg, _, xe := config.LoadConfig(config.Options{
ConfigPath: GlobalConfig.ConfigStr,
Expand All @@ -75,8 +81,22 @@

switch resolved.transport {
case mcp_pkg.TransportStdio:
ctx := context.Background()
return server.Run(ctx, &mcp.StdioTransport{})
// Install signal handler for graceful shutdown in stdio mode
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
signal.Stop(sigChan)
cancel()
}()

if err := runMCPStdioServer(ctx, server); err != nil && !stderrors.Is(err, context.Canceled) {
return err
}
return nil
case mcp_pkg.TransportStreamableHTTP:
handler, err := mcp_pkg.NewStreamableHTTPHandler(server, resolved.httpAuthToken)
if err != nil {
Expand All @@ -86,21 +106,30 @@
return errors.Wrap(errors.CodeInternal, "failed to create streamable http handler", nil, err)
}
httpServer := &http.Server{
Addr: resolved.httpAddr,
Handler: handler,
Addr: resolved.httpAddr,
Handler: handler,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

go func() {
<-sigChan
signal.Stop(sigChan)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_ = httpServer.Shutdown(ctx)
if shutdownErr := httpServer.Shutdown(ctx); shutdownErr != nil {
log.Printf("[mcp] http server shutdown error: %v", shutdownErr)
}
}()

return httpServer.ListenAndServe()
if listenErr := httpServer.ListenAndServe(); listenErr != nil && listenErr != http.ErrServerClosed {
return listenErr
}
return nil
default:
return errors.New(errors.CodeCfgInvalid, "unsupported mcp transport", map[string]any{"transport": resolved.transport})
}
Expand Down
11 changes: 10 additions & 1 deletion internal/app/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package app
import (
"context"
"database/sql"
"sync"

"github.com/zx06/xsql/internal/config"
"github.com/zx06/xsql/internal/db"
Expand All @@ -15,6 +16,7 @@ type Connection struct {
DB *sql.DB
SSHClient *ssh.Client
Profile config.Profile
closeMu sync.Mutex
closeHooks []func()
}

Expand All @@ -30,7 +32,11 @@ func (c *Connection) Close() error {
errs = append(errs, err)
}
}
for _, fn := range c.closeHooks {
c.closeMu.Lock()
hooks := c.closeHooks
c.closeHooks = nil
c.closeMu.Unlock()
for _, fn := range hooks {
if fn != nil {
fn()
}
Expand Down Expand Up @@ -81,6 +87,7 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection
}

closeHooks := make([]func(), 0, 1)
var hooksMu sync.Mutex
connOpts := db.ConnOptions{
DSN: opts.Profile.DSN,
Host: opts.Profile.Host,
Expand All @@ -90,7 +97,9 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection
Database: opts.Profile.Database,
RegisterCloseHook: func(fn func()) {
if fn != nil {
hooksMu.Lock()
closeHooks = append(closeHooks, fn)
hooksMu.Unlock()
}
},
}
Expand Down
45 changes: 25 additions & 20 deletions internal/db/mysql/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ import (
)

var (
dialerCounter uint64
dialers sync.Map
registeredDials sync.Map
dialerCounter uint64
dialers sync.Map
registerDialContextFn = mysql.RegisterDialContext
deregisterDialContextFn = mysql.DeregisterDialContext
)

func init() {
Expand All @@ -30,24 +31,30 @@ func registerDialContext(dialer func(context.Context, string, string) (net.Conn,
dialerNum := atomic.AddUint64(&dialerCounter, 1)
dialName := fmt.Sprintf("xsql_ssh_tunnel_%d", dialerNum)

if _, loaded := registeredDials.LoadOrStore(dialName, true); !loaded {
mysql.RegisterDialContext(dialName, func(ctx context.Context, addr string) (net.Conn, error) {
d, ok := dialers.Load(dialName)
if !ok {
return nil, fmt.Errorf("dialer not found: %s", dialName)
}
fn, ok := d.(func(context.Context, string, string) (net.Conn, error))
if !ok || fn == nil {
return nil, fmt.Errorf("invalid dialer for: %s", dialName)
}
return fn(ctx, "tcp", addr)
})
}
registerDialContextFn(dialName, func(ctx context.Context, addr string) (net.Conn, error) {
d, ok := dialers.Load(dialName)
if !ok {
return nil, fmt.Errorf("dialer not found: %s", dialName)
}
fn, ok := d.(func(context.Context, string, string) (net.Conn, error))
if !ok || fn == nil {
return nil, fmt.Errorf("invalid dialer for: %s", dialName)
}
return fn(ctx, "tcp", addr)
})

dialers.Store(dialName, dialer)
return dialName
}

func cleanupDialContext(dialName string) {
if dialName == "" {
return
}
dialers.Delete(dialName)
deregisterDialContextFn(dialName)
}

type Driver struct{}

func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *errors.XError) {
Expand Down Expand Up @@ -80,7 +87,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error
cfg.Net = dialName
if opts.RegisterCloseHook != nil {
opts.RegisterCloseHook(func() {
dialers.Delete(dialName)
cleanupDialContext(dialName)
})
}
}
Expand All @@ -94,9 +101,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error
if closeErr := conn.Close(); closeErr != nil {
log.Printf("failed to close mysql connection: %v", closeErr)
}
if dialName != "" {
dialers.Delete(dialName)
}
cleanupDialContext(dialName)
return nil, errors.Wrap(errors.CodeDBConnectFailed, "failed to ping mysql", nil, err)
}
return conn, nil
Expand Down
17 changes: 16 additions & 1 deletion internal/db/mysql/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -197,7 +198,15 @@ func TestDriver_Open_ContextCancelled(t *testing.T) {

func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) {
resetSyncMap(&dialers)
resetSyncMap(&registeredDials)
var deregisterCalls int32
prevDeregister := deregisterDialContextFn
deregisterDialContextFn = func(net string) {
atomic.AddInt32(&deregisterCalls, 1)
prevDeregister(net)
}
defer func() {
deregisterDialContextFn = prevDeregister
}()

drv, _ := db.Get("mysql")
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
Expand Down Expand Up @@ -227,11 +236,17 @@ func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) {
if countSyncMap(&dialers) != 0 {
t.Fatal("expected dialers map to be cleaned on open failure")
}
if got := atomic.LoadInt32(&deregisterCalls); got != 1 {
t.Fatalf("expected one deregister call on open failure, got %d", got)
}

hooks[0]()
if countSyncMap(&dialers) != 0 {
t.Fatal("expected hook cleanup to be idempotent")
}
if got := atomic.LoadInt32(&deregisterCalls); got != 2 {
t.Fatalf("expected close hook cleanup to remain safe, got %d deregister calls", got)
}
}

func countSyncMap(m *sync.Map) int {
Expand Down
6 changes: 6 additions & 0 deletions internal/mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ type ToolHandler struct {

// NewToolHandler creates a new tool handler
func NewToolHandler(cfg *config.File) *ToolHandler {
if cfg == nil {
cfg = &config.File{
Profiles: map[string]config.Profile{},
SSHProxies: map[string]config.SSHProxy{},
}
}
return &ToolHandler{
config: cfg,
}
Expand Down
29 changes: 26 additions & 3 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,22 @@ func (p *Proxy) handleConnection(localConn net.Conn, remoteAddr string) {

// Bidirectional copy
var wg sync.WaitGroup
wg.Add(2)
errChan := make(chan error, 2)

wg.Add(1)
go func() {
defer wg.Done()
_, _ = io.Copy(localConn, remoteConn)
if _, err := io.Copy(localConn, remoteConn); err != nil {
errChan <- fmt.Errorf("copy remote->local failed: %w", err)
}
}()

wg.Add(1)
go func() {
defer wg.Done()
_, _ = io.Copy(remoteConn, localConn)
if _, err := io.Copy(remoteConn, localConn); err != nil {
errChan <- fmt.Errorf("copy local->remote failed: %w", err)
}
}()

// Wait for both copies to finish or context cancellation
Expand All @@ -172,7 +178,24 @@ func (p *Proxy) handleConnection(localConn net.Conn, remoteAddr string) {

select {
case <-done:
// Check if there were any copy errors
select {
case err := <-errChan:
log.Printf("[proxy] connection copy error: %v", err)
default:
}
case <-p.ctx.Done():
// Context cancelled: close connections to unblock io.Copy goroutines
_ = localConn.Close()
_ = remoteConn.Close()
// Wait for goroutines to finish
<-done
// Check for any final errors
select {
case err := <-errChan:
log.Printf("[proxy] connection copy error on shutdown: %v", err)
default:
}
}
}

Expand Down
Loading
Loading