diff --git a/exec/cmd.go b/exec/cmd.go new file mode 100644 index 00000000..324fe8fb --- /dev/null +++ b/exec/cmd.go @@ -0,0 +1,402 @@ +package exec + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/crossplane/crossplane-runtime/pkg/resource" + "github.com/mattn/go-isatty" + meta "github.com/ninech/apis/meta/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/get" + "github.com/ninech/nctl/internal/cli" + "github.com/ninech/nctl/internal/format" + "github.com/ninech/nctl/internal/ipcheck" +) + +// cmdExecutor encapsulates resource-specific logic for connecting via an external CLI. +type cmdExecutor[T resource.Managed] interface { + // Command returns the CLI binary name (e.g. "psql", "mysql", "redis-cli"). + // Used for the early path check before any credential fetching. + Command() string + + // Endpoint returns "host:port" for the TCP connectivity check. + Endpoint(res T) string + + // NewCmd builds the *exec.Cmd for connecting to res with the given credentials. + // Env is set on the returned Cmd; stdio and ExtraArgs are wired by connectAndExec. + // The returned cleanup func removes any temp files created (e.g. CA cert, options file). + NewCmd(ctx context.Context, res T, user, pw string) (cmd *exec.Cmd, cleanup func(), err error) +} + +// accessManager extends cmdExecutor for resources that have access restrictions. +type accessManager[T resource.Managed] interface { + // AllowedCIDRs returns the current list of allowed CIDRs for the resource. + AllowedCIDRs(res T) []meta.IPv4CIDR + + // Update patches the resource to allow the given CIDRs. + Update(ctx context.Context, client *api.Client, res T, cidrs []meta.IPv4CIDR) error +} + +// serviceCmd is the shared base for all database exec sub-commands. +type serviceCmd struct { + resourceCmd + format.Writer `kong:"-"` + format.Reader `kong:"-"` + AllowedCidrs *[]meta.IPv4CIDR `placeholder:"203.0.113.1/32" help:"Specifies the IP addresses allowed to connect to the instance. Overrides auto-detected public IP."` + WaitTimeout time.Duration `default:"3m" help:"Timeout waiting for connectivity."` + ExtraArgs []string `arg:"" optional:"" passthrough:"" help:"Additional flags passed to the CLI (after --)."` + + // Internal dependencies — nil means use production default. + runCommand func(cmd *exec.Cmd) error `kong:"-"` + lookPath func(file string) (string, error) `kong:"-"` + waitForConnectivity func(ctx context.Context, writer format.Writer, endpoint string, timeout time.Duration) error `kong:"-"` + openTTYForConfirm func() (io.ReadCloser, error) `kong:"-"` +} + +// BeforeApply initializes Writer and Reader from Kong's bound io.Writer and io.Reader. +func (cmd *serviceCmd) BeforeApply(writer io.Writer, reader io.Reader) error { + return errors.Join( + cmd.Writer.BeforeApply(writer), + cmd.Reader.BeforeApply(reader), + ) +} + +func (cmd serviceCmd) getRunCommand() func(*exec.Cmd) error { + if cmd.runCommand != nil { + return cmd.runCommand + } + return func(c *exec.Cmd) error { + return c.Run() + } +} + +func (cmd serviceCmd) getLookPath() func(string) (string, error) { + if cmd.lookPath != nil { + return cmd.lookPath + } + + return exec.LookPath +} + +func (cmd serviceCmd) connectivityCheck() func(context.Context, format.Writer, string, time.Duration) error { + if cmd.waitForConnectivity != nil { + return cmd.waitForConnectivity + } + + return waitForConnectivity +} + +// openTTY returns the openTTY function to use for confirming prompts. +func (cmd serviceCmd) openTTY() func() (io.ReadCloser, error) { + if cmd.openTTYForConfirm != nil { + return cmd.openTTYForConfirm + } + + return func() (io.ReadCloser, error) { + return os.Open("/dev/tty") + } +} + +// connectAndExec is the main orchestration function for exec commands. +// It handles path checking, connectivity waiting, and credential retrieval. +func connectAndExec[T resource.Managed]( + ctx context.Context, + client *api.Client, + res T, + connector cmdExecutor[T], + opts serviceCmd, +) error { + if err := opts.checkPath(connector.Command()); err != nil { + return err + } + + endpoint := connector.Endpoint(res) + if endpoint == "" { + return fmt.Errorf("resource %q is not ready yet (no endpoint available)", res.GetName()) + } + + if !quickDial(ctx, endpoint) { + if am, ok := connector.(accessManager[T]); ok { + if err := ensureAccess(ctx, client, am, res, opts); err != nil { + return err + } + } + + if err := opts.connectivityCheck()(ctx, opts.Writer, endpoint, opts.WaitTimeout); err != nil { + return err + } + } + + user, pw, err := getCredentials(ctx, client, res) + if err != nil { + return err + } + + cmd, cleanup, err := connector.NewCmd(ctx, res, user, pw) + if err != nil { + return fmt.Errorf("building CLI command: %w", err) + } + defer cleanup() + + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = append(os.Environ(), cmd.Env...) + cmd.Args = append(cmd.Args, opts.ExtraArgs...) + + if err := opts.getRunCommand()(cmd); err != nil { + if exitErr, ok := errors.AsType[*exec.ExitError](err); ok { + return cli.ErrorWithContext(err).WithExitCode(exitErr.ExitCode()) + } + return err + } + + return nil +} + +// ensureAccess detects the caller's public IP (or uses the overridden list), +// checks whether it is already permitted, and if not prompts the user before +// calling connector.Update. +func ensureAccess[T resource.Managed]( + ctx context.Context, + client *api.Client, + connector accessManager[T], + res T, + cmd serviceCmd, +) error { + var toAdd []meta.IPv4CIDR + + if cmd.AllowedCidrs != nil { + toAdd = *cmd.AllowedCidrs + + if cidrsPresent(connector.AllowedCIDRs(res), toAdd) { + cmd.Infof("✅", "specified CIDRs are already allowed") + return nil + } + } else { + ip, err := ipcheck.New(ipcheck.WithUserAgent(cli.Name)).PublicIP(ctx) + if err != nil { + return cli.ErrorWithContext(fmt.Errorf("detecting public IP address: %w", err)). + WithSuggestions("Are you connected to the internet?") + } + if ip.Blocked { + return cli.ErrorWithContext(fmt.Errorf("public IP seems to be blocked")). + WithContext("IP", ip.RemoteAddr.String()). + WithSuggestions("Reach out to support@nine.ch.") + } + cmd.Infof("🌐", "detected public IP: %s", ip.RemoteAddr) + + if cidr := ipCoveredByCIDRs(ip.RemoteAddr, connector.AllowedCIDRs(res)); cidr != nil { + cmd.Infof("✅", "IP %s is already allowedby %s", ip.RemoteAddr, cidr.String()) + return nil + } + + toAdd = []meta.IPv4CIDR{meta.IPv4CIDR(netip.PrefixFrom(ip.RemoteAddr, 32).String())} + } + + msg := fmt.Sprintf("Add %v to the allowed CIDRs of %q?", toAdd, res.GetName()) + ok, err := cmd.confirm(msg) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("CIDR addition canceled") + } + + // Merge with existing CIDRs. + merged := appendMissing(connector.AllowedCIDRs(res), toAdd) + if err := connector.Update(ctx, client, res, merged); err != nil { + return fmt.Errorf("updating allowed CIDRs: %w", err) + } + + return nil +} + +// waitForConnectivity dials endpoint in a retry loop until it succeeds or timeout expires. +func waitForConnectivity(ctx context.Context, writer format.Writer, endpoint string, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + spinner, err := writer.Spinner( + format.Progressf("⏳", "waiting for connectivity to %s", endpoint), + format.Progressf("✅", "connected to %s", endpoint), + ) + if err != nil { + return err + } + + _ = spinner.Start() + defer func() { _ = spinner.Stop() }() + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + attemptCtx, attemptCancel := context.WithTimeout(ctx, 3*time.Second) + dialErr := dialTCP(attemptCtx, endpoint) + attemptCancel() + if dialErr == nil { + _ = spinner.Stop() + return nil + } + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + msg := "timeout waiting for connectivity to %s" + spinner.StopFailMessage(format.Progressf("", msg, endpoint)) + _ = spinner.StopFail() + return fmt.Errorf(msg, endpoint) + default: + _ = spinner.StopFail() + return nil + } + case <-ticker.C: + } + } +} + +// checkPath verifies that the named CLI binary is installed and on PATH. +func (cmd serviceCmd) checkPath(name string) error { + if _, err := cmd.getLookPath()(name); err != nil { + return cli.ErrorWithContext(fmt.Errorf("%q CLI not found", name)). + WithSuggestions( + fmt.Sprintf("Install %q and ensure it is available in your PATH.", name), + ) + } + return nil +} + +// confirm prints a confirmation prompt. When stdin is not a TTY it opens /dev/tty +// so that piped input (e.g. SQL dumps) does not consume the prompt, mirroring +// the pattern used by git and ssh. +func (cmd serviceCmd) confirm(msg string) (bool, error) { + if !isatty.IsTerminal(os.Stdin.Fd()) { + tty, err := cmd.openTTY()() + if err == nil { + defer tty.Close() + return cmd.Confirm(format.NewReader(tty), msg) + } + } + return cmd.Confirm(cmd.Reader, msg) +} + +// getCredentials fetches the connection secret for the given resource and +// returns the first username/password pair found. +func getCredentials(ctx context.Context, client *api.Client, mg resource.Managed) (string, string, error) { + secret, err := get.ConnectionSecretMap(ctx, client, mg) + if err != nil { + return "", "", fmt.Errorf("getting connection secret: %w", err) + } + + for user, pw := range secret { + return user, string(pw), nil + } + + return "", "", fmt.Errorf("connection secret %q contains no credentials", mg.GetWriteConnectionSecretToReference().Name) +} + +// ipCoveredByCIDRs reports whether ip is contained in any of the given CIDRs. +func ipCoveredByCIDRs(ip netip.Addr, cidrs []meta.IPv4CIDR) *netip.Prefix { + for _, cidr := range cidrs { + p, err := netip.ParsePrefix(string(cidr)) + if err != nil { + continue + } + if p.Contains(ip) { + return &p + } + } + + return nil +} + +// cidrsPresent reports whether all of want are present in current. +func cidrsPresent(current []meta.IPv4CIDR, want []meta.IPv4CIDR) bool { + set := make(map[meta.IPv4CIDR]struct{}, len(current)) + for _, c := range current { + set[c] = struct{}{} + } + for _, w := range want { + if _, ok := set[w]; !ok { + return false + } + } + return true +} + +// appendMissing appends any CIDRs from add that are not already in current. +func appendMissing(current []meta.IPv4CIDR, add []meta.IPv4CIDR) []meta.IPv4CIDR { + set := make(map[meta.IPv4CIDR]struct{}, len(current)) + for _, c := range current { + set[c] = struct{}{} + } + result := append([]meta.IPv4CIDR(nil), current...) + for _, a := range add { + if _, ok := set[a]; !ok { + result = append(result, a) + } + } + return result +} + +// dialTCP opens a single TCP connection to endpoint, respecting ctx for +// cancellation and deadline. +func dialTCP(ctx context.Context, endpoint string) error { + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", endpoint) + if err != nil { + return err + } + _ = conn.Close() + return nil +} + +// quickDial attempts a single TCP connection with a short timeout. +// Returns true when the endpoint is immediately reachable. +func quickDial(ctx context.Context, endpoint string) bool { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + return dialTCP(ctx, endpoint) == nil +} + +// createTempDir creates a temporary working directory for nctl runtime files +// and returns its path along with a cleanup function that removes it. +func createTempDir() (string, func(), error) { + dir, err := os.MkdirTemp("", "nctl-*") + if err != nil { + return "", func() {}, fmt.Errorf("creating temp dir: %w", err) + } + return dir, func() { _ = os.RemoveAll(dir) }, nil +} + +// writeCACert decodes a base64-encoded PEM CA certificate and writes it into +// dir. Returns the file path, or an empty string if caCert is empty. +func writeCACert(dir, caCert string) (string, error) { + if dir == "" || caCert == "" { + return "", nil + } + + path := filepath.Join(dir, "ca.pem") + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return "", fmt.Errorf("creating CA cert temp file %q: %w", path, err) + } + defer f.Close() + + if err := get.WriteBase64(f, caCert); err != nil { + return "", fmt.Errorf("writing CA cert %q: %w", path, err) + } + + return f.Name(), nil +} diff --git a/exec/cmd_test.go b/exec/cmd_test.go new file mode 100644 index 00000000..e24a6cac --- /dev/null +++ b/exec/cmd_test.go @@ -0,0 +1,74 @@ +package exec + +import ( + "bytes" + "context" + "fmt" + "io" + "os/exec" + "strings" + "time" + + meta "github.com/ninech/apis/meta/v1alpha1" + "github.com/ninech/nctl/internal/format" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// capturingCmd records the exec.Cmd passed to runCommand. +type capturingCmd struct { + cmd *exec.Cmd +} + +// testSecret creates a corev1.Secret with a single username→password entry. +func testSecret(name, namespace, user, password string) *corev1.Secret { + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Data: map[string][]byte{ + user: []byte(password), + }, + } +} + +// testDatabaseCmd returns a capturingCmd and a databaseCmd wired with no-op +// writer/reader and test-friendly function fields. +// When cidrs is non-nil those CIDRs are used; when nil the IP detection is +// triggered only for instance resources (which is safe to use in tests if the +// connector returns nil from AllowedCIDRs). +func testDatabaseCmd(name string, cidrs *[]meta.IPv4CIDR) (*capturingCmd, serviceCmd) { + return testDatabaseCmdConfirmed(name, cidrs, false) +} + +// testDatabaseCmdConfirmed is like testDatabaseCmd but pre-seeds the reader +// with "y\n" so that confirmation prompts are auto-accepted. +func testDatabaseCmdConfirmed(name string, cidrs *[]meta.IPv4CIDR, confirmed bool) (*capturingCmd, serviceCmd) { + var reader io.Reader = &bytes.Buffer{} + if confirmed { + reader = strings.NewReader("y\n") + } + cap := &capturingCmd{} + cmd := serviceCmd{ + resourceCmd: resourceCmd{Name: name}, + Writer: format.NewWriter(&bytes.Buffer{}), + Reader: format.NewReader(reader), + AllowedCidrs: cidrs, + WaitTimeout: 0, + runCommand: func(c *exec.Cmd) error { + cap.cmd = c + return nil + }, + lookPath: func(file string) (string, error) { + return "/usr/bin/" + file, nil + }, + waitForConnectivity: func(_ context.Context, _ format.Writer, _ string, _ time.Duration) error { + return nil + }, + openTTYForConfirm: func() (io.ReadCloser, error) { + return nil, fmt.Errorf("no tty in tests") + }, + } + return cap, cmd +} diff --git a/exec/exec.go b/exec/exec.go index f37d6a41..b1d82df2 100644 --- a/exec/exec.go +++ b/exec/exec.go @@ -1,10 +1,16 @@ // Package exec provides the implementation for the exec command. package exec +// Cmd holds all exec sub-commands. type Cmd struct { - Application applicationCmd `cmd:"" group:"deplo.io" aliases:"app,application" name:"application" help:"Execute a command or shell in a deplo.io application."` + Application applicationCmd `cmd:"" group:"deplo.io" aliases:"app,application" name:"application" help:"Execute a command or shell in a deplo.io application."` + Postgres postgresCmd `cmd:"" group:"storage.nine.ch" name:"postgres" help:"Connect to a PostgreSQL instance."` + PostgresDatabase postgresDatabaseCmd `cmd:"" group:"storage.nine.ch" name:"postgresdatabase" help:"Connect to a PostgreSQL database."` + MySQL mysqlCmd `cmd:"" group:"storage.nine.ch" name:"mysql" help:"Connect to a MySQL instance."` + MySQLDatabase mysqlDatabaseCmd `cmd:"" group:"storage.nine.ch" name:"mysqldatabase" help:"Connect to a MySQL database."` + KeyValueStore kvsCmd `cmd:"" group:"storage.nine.ch" name:"keyvaluestore" aliases:"kvs" help:"Connect to a KeyValueStore instance."` } type resourceCmd struct { - Name string `arg:"" completion-predictor:"resource_name" help:"Name of the application to exec command/shell in." required:""` + Name string `arg:"" completion-predictor:"resource_name" help:"Name of the resource." required:""` } diff --git a/exec/keyvaluestore.go b/exec/keyvaluestore.go new file mode 100644 index 00000000..54a7b53f --- /dev/null +++ b/exec/keyvaluestore.go @@ -0,0 +1,104 @@ +package exec + +import ( + "context" + "fmt" + "os/exec" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/cli" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const kvsPort = "6379" + +type kvsCmd struct { + serviceCmd +} + +// Help displays usage examples for the keyvaluestore exec command. +func (cmd kvsCmd) Help() string { + return `Examples: + # Connect to a KeyValueStore instance interactively + nctl exec keyvaluestore mykvs + + # Pass extra flags to redis-cli (after --) + nctl exec keyvaluestore mykvs -- --no-auth-warning +` +} + +func (cmd *kvsCmd) Run(ctx context.Context, client *api.Client) error { + kvs := &storage.KeyValueStore{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), kvs); err != nil { + return fmt.Errorf("getting keyvaluestore %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, kvs, kvsConnector{}, cmd.serviceCmd) +} + +// kvsConnector implements ServiceConnector for storage.KeyValueStore instances. +type kvsConnector struct{} + +func (kvsConnector) Command() string { return "redis-cli" } + +func (kvsConnector) Endpoint(kvs *storage.KeyValueStore) string { + if kvs.Status.AtProvider.FQDN == "" { + return "" + } + return kvs.Status.AtProvider.FQDN + ":" + kvsPort +} + +func (kvsConnector) AllowedCIDRs(kvs *storage.KeyValueStore) []meta.IPv4CIDR { + return kvs.Spec.ForProvider.AllowedCIDRs +} + +func (kvsConnector) Update(ctx context.Context, client *api.Client, kvs *storage.KeyValueStore, cidrs []meta.IPv4CIDR) error { + current := &storage.KeyValueStore{} + if err := client.Get(ctx, api.ObjectName(kvs), current); err != nil { + return err + } + + if current.Spec.ForProvider.PublicNetworkingEnabled != nil && !*current.Spec.ForProvider.PublicNetworkingEnabled { + return cli.ErrorWithContext(fmt.Errorf("public networking is disabled for keyvaluestore %q", kvs.GetName())). + WithSuggestions( + fmt.Sprintf("Enable it with: %s update keyvaluestore %s --public-networking", cli.Name, kvs.GetName()), + ) + } + + current.Spec.ForProvider.AllowedCIDRs = cidrs + return client.Update(ctx, current) +} + +// NewCmd builds the redis-cli command. The auth token is passed via REDISCLI_AUTH +// rather than -a so it does not appear in the process argument list. +func (kvsConnector) NewCmd(ctx context.Context, kvs *storage.KeyValueStore, _ string, pw string) (*exec.Cmd, func(), error) { + dir, cleanup, err := createTempDir() + if err != nil { + return nil, func() {}, err + } + + caPath, err := writeCACert(dir, kvs.Status.AtProvider.CACert) + if err != nil { + cleanup() + return nil, func() {}, err + } + + args := []string{ + "-h", kvs.Status.AtProvider.FQDN, + "-p", kvsPort, + "--tls", + } + if caPath != "" { + args = append(args, "--cacert", caPath) + } + + cmd := exec.CommandContext(ctx, "redis-cli", args...) + cmd.Env = []string{"REDISCLI_AUTH=" + pw} + return cmd, cleanup, nil +} diff --git a/exec/keyvaluestore_test.go b/exec/keyvaluestore_test.go new file mode 100644 index 00000000..02d4d436 --- /dev/null +++ b/exec/keyvaluestore_test.go @@ -0,0 +1,148 @@ +package exec + +import ( + "context" + "os/exec" + "strings" + "testing" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestKVSCmd(t *testing.T) { + t.Parallel() + + const ( + kvsName = "mykvs" + kvsFQDN = "mykvs.example.com" + kvsToken = "supersecrettoken" + ) + + cidr := []meta.IPv4CIDR{"203.0.113.5/32"} + pubNet := true + + ready := test.KeyValueStore(kvsName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = kvsFQDN + ready.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + ready.Spec.ForProvider.PublicNetworkingEnabled = &pubNet + + pubNetFalse := false + pubNetDisabled := test.KeyValueStore("no-public", test.DefaultProject, "nine-es34") + pubNetDisabled.Status.AtProvider.FQDN = "no-public.example.com" + pubNetDisabled.Spec.ForProvider.PublicNetworkingEnabled = &pubNetFalse + pubNetDisabled.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{} + + notReady := test.KeyValueStore("notready", test.DefaultProject, "nine-es34") + + // KVS secret: single key with auth token as value. + secret := testSecret(kvsName, test.DefaultProject, "token", kvsToken) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", &cidr) + _, notReadyCmd := testDatabaseCmd("notready", &cidr) + alreadyCap, alreadyPresentCmd := testDatabaseCmd(kvsName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + _, newCidrCmd := testDatabaseCmdConfirmed(kvsName, &cidr, true) + _, pubNetDisabledCmd := testDatabaseCmdConfirmed("no-public", &cidr, true) + tokenCap, tokenCmd := testDatabaseCmd(kvsName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + + tests := []struct { + name string + cmd kvsCmd + cap *capturingCmd + wantErr bool + errContains string + wantUpdate bool + check func(t *testing.T, cmd *exec.Cmd) + }{ + { + name: "resource not found", + cmd: kvsCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: kvsCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "cidr already present skips update", + cmd: kvsCmd{serviceCmd: alreadyPresentCmd}, + cap: alreadyCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if !strings.Contains(strings.Join(cmd.Args, " "), kvsFQDN) { + t.Errorf("expected FQDN %q in args %v", kvsFQDN, cmd.Args) + } + }, + }, + { + name: "new cidr triggers update", + cmd: kvsCmd{serviceCmd: newCidrCmd}, + wantUpdate: true, + }, + { + name: "public networking disabled returns error", + cmd: kvsCmd{serviceCmd: pubNetDisabledCmd}, + wantErr: true, + errContains: "networking is disabled", + }, + { + name: "token passed securely via env", + cmd: kvsCmd{serviceCmd: tokenCmd}, + cap: tokenCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if strings.Contains(strings.Join(cmd.Args, " "), kvsToken) { + t.Errorf("token must not appear in args %v", cmd.Args) + } + if !containsEnv(cmd.Env, "REDISCLI_AUTH="+kvsToken) { + t.Errorf("expected REDISCLI_AUTH env var, got %v", cmd.Env) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, pubNetDisabled, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if tc.wantUpdate && !updateCalled { + t.Error("expected Update to be called for CIDR addition") + } + if !tc.wantErr && tc.check != nil { + tc.check(t, tc.cap.cmd) + } + if tc.wantUpdate { + kvs := &storage.KeyValueStore{} + if err := apiClient.Get(t.Context(), api.ObjectName(ready), kvs); err != nil { + t.Fatalf("getting kvs: %v", err) + } + if !cidrsPresent(kvs.Spec.ForProvider.AllowedCIDRs, cidr) { + t.Errorf("expected CIDR %v to be added, got %v", cidr, kvs.Spec.ForProvider.AllowedCIDRs) + } + } + }) + } +} diff --git a/exec/mysql.go b/exec/mysql.go new file mode 100644 index 00000000..36870052 --- /dev/null +++ b/exec/mysql.go @@ -0,0 +1,155 @@ +package exec + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + mysqlPort = "3306" + mysqlCommand = "mysql" +) + +type mysqlCmd struct { + serviceCmd + Database string `name:"database" short:"d" completion-predictor:"mysql_databases" help:"Database name to connect to."` +} + +// Help displays usage examples for the mysql exec command. +func (cmd mysqlCmd) Help() string { + return `Examples: + # Connect to a MySQL instance interactively + nctl exec mysql myinstance + + # Connect to a specific database + nctl exec mysql myinstance -d mydb + + # Import a SQL dump via pipe + cat dump.sql | nctl exec mysql myinstance + + # Pass extra flags to mysql (after --) + nctl exec mysql myinstance -- --batch +` +} + +func (cmd *mysqlCmd) Run(ctx context.Context, client *api.Client) error { + my := &storage.MySQL{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), my); err != nil { + return fmt.Errorf("getting mysql %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, my, mysqlConnector{database: cmd.Database}, cmd.serviceCmd) +} + +// mysqlConnector implements cmdExecutor for storage.MySQL instances. +type mysqlConnector struct { + database string +} + +func (mysqlConnector) Command() string { return mysqlCommand } + +func (mysqlConnector) Endpoint(my *storage.MySQL) string { + if my.Status.AtProvider.FQDN == "" { + return "" + } + return my.Status.AtProvider.FQDN + ":" + mysqlPort +} + +func (mysqlConnector) AllowedCIDRs(my *storage.MySQL) []meta.IPv4CIDR { + return my.Spec.ForProvider.AllowedCIDRs +} + +func (mysqlConnector) Update(ctx context.Context, client *api.Client, my *storage.MySQL, cidrs []meta.IPv4CIDR) error { + current := &storage.MySQL{} + if err := client.Get(ctx, api.ObjectName(my), current); err != nil { + return err + } + current.Spec.ForProvider.AllowedCIDRs = cidrs + return client.Update(ctx, current) +} + +// NewCmd builds the mysql command. Credentials are passed via --defaults-extra-file +// rather than -u/-p flags so they do not appear in the process argument list. +func (c mysqlConnector) NewCmd(ctx context.Context, my *storage.MySQL, user, pw string) (*exec.Cmd, func(), error) { + return newMySQLCmd(ctx, my.Status.AtProvider.FQDN, c.database, my.Status.AtProvider.CACert, user, pw) +} + +// newMySQLCmd returns an exec.Cmd for mysql with credentials in a temp options file. +// When a CA cert is provided the connection uses VERIFY_CA, otherwise REQUIRED. +func newMySQLCmd(ctx context.Context, fqdn, dbName, caCertBase64, user, pw string) (*exec.Cmd, func(), error) { + dir, cleanup, err := createTempDir() + if err != nil { + return nil, func() {}, err + } + + caPath, err := writeCACert(dir, caCertBase64) + if err != nil { + cleanup() + return nil, func() {}, err + } + + cfgPath, err := writeMySQLConfig(dir, user, pw) + if err != nil { + cleanup() + return nil, func() {}, err + } + + // --defaults-extra-file must precede all other options. + args := []string{ + "--defaults-extra-file=" + cfgPath, + "-h", fqdn, + "-P", mysqlPort, + } + if caPath != "" { + args = append(args, "--ssl-ca="+caPath, "--ssl-mode=VERIFY_CA") + } else { + args = append(args, "--ssl-mode=REQUIRED") + } + if dbName != "" { + args = append(args, dbName) + } + + return exec.CommandContext(ctx, mysqlCommand, args...), cleanup, nil +} + +// writeMySQLConfig writes a temporary MySQL options file into dir containing +// the given credentials. The file is mode 0600 so other local users cannot read it. +func writeMySQLConfig(dir, user, pw string) (string, error) { + if dir == "" { + return "", nil + } + + path := filepath.Join(dir, "my.cnf") + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return "", fmt.Errorf("creating MySQL config temp file %q: %w", path, err) + } + defer f.Close() + + if _, err = fmt.Fprintf(f, "[client]\nuser=%s\npassword=%s\n", mysqlConfigEscape(user), mysqlConfigEscape(pw)); err != nil { + return "", fmt.Errorf("writing MySQL config %q: %w", path, err) + } + + return f.Name(), nil +} + +// mysqlConfigEscape escapes a value for use in a MySQL option file. +// Values are double-quoted; internal double quotes and backslashes are escaped. +func mysqlConfigEscape(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `"`, `\"`) + return `"` + s + `"` +} diff --git a/exec/mysql_test.go b/exec/mysql_test.go new file mode 100644 index 00000000..e60d77b0 --- /dev/null +++ b/exec/mysql_test.go @@ -0,0 +1,178 @@ +package exec + +import ( + "context" + "encoding/base64" + "os/exec" + "strings" + "testing" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestMySQLCmd(t *testing.T) { + t.Parallel() + + const ( + myName = "mymy" + myFQDN = "mymy.example.com" + myUser = "root" + myPass = "rootpass" + ) + + cidr := []meta.IPv4CIDR{"203.0.113.5/32"} + + ready := test.MySQL(myName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = myFQDN + ready.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + + readyWithCA := test.MySQL(myName+"-ca", test.DefaultProject, "nine-es34") + readyWithCA.Status.AtProvider.FQDN = myFQDN + readyWithCA.Status.AtProvider.CACert = base64.StdEncoding.EncodeToString([]byte("fake-ca-cert")) + readyWithCA.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + + notReady := test.MySQL("notready", test.DefaultProject, "nine-es34") + + secret := testSecret(myName, test.DefaultProject, myUser, myPass) + secretWithCA := testSecret(myName+"-ca", test.DefaultProject, myUser, myPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", &cidr) + _, notReadyCmd := testDatabaseCmd("notready", &cidr) + alreadyCap, alreadyPresentCmd := testDatabaseCmd(myName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + _, newCidrCmd := testDatabaseCmdConfirmed(myName, &cidr, true) + credsCap, credsCmd := testDatabaseCmd(myName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + dbCap, dbCmd := testDatabaseCmd(myName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + sslCap, sslCmd := testDatabaseCmd(myName+"-ca", &[]meta.IPv4CIDR{"10.0.0.1/32"}) + + tests := []struct { + name string + cmd mysqlCmd + cap *capturingCmd + objects []runtimeclient.Object + wantErr bool + errContains string + wantUpdate bool + check func(t *testing.T, cmd *exec.Cmd) + }{ + { + name: "resource not found", + cmd: mysqlCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: mysqlCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "cidr already present skips update", + cmd: mysqlCmd{serviceCmd: alreadyPresentCmd}, + cap: alreadyCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if !strings.Contains(strings.Join(cmd.Args, " "), myFQDN) { + t.Errorf("expected FQDN %q in args %v", myFQDN, cmd.Args) + } + }, + }, + { + name: "new cidr triggers update", + cmd: mysqlCmd{serviceCmd: newCidrCmd}, + wantUpdate: true, + }, + { + name: "credentials passed securely", + cmd: mysqlCmd{serviceCmd: credsCmd}, + cap: credsCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + argsStr := strings.Join(cmd.Args, " ") + if strings.Contains(argsStr, myPass) { + t.Errorf("password must not appear in args %v", cmd.Args) + } + if strings.Contains(argsStr, myUser) && !strings.Contains(argsStr, "--defaults-extra-file") { + t.Errorf("user must not appear in plain args %v", cmd.Args) + } + if !strings.Contains(argsStr, "--defaults-extra-file=") { + t.Errorf("expected --defaults-extra-file in args %v", cmd.Args) + } + }, + }, + { + name: "custom database appears in args", + cmd: mysqlCmd{serviceCmd: dbCmd, Database: "mydb"}, + cap: dbCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if !strings.Contains(strings.Join(cmd.Args, " "), "mydb") { + t.Errorf("expected database %q in args %v", "mydb", cmd.Args) + } + }, + }, + { + name: "ssl mode is VERIFY_CA when CA cert is present", + cmd: mysqlCmd{serviceCmd: sslCmd}, + cap: sslCap, + objects: []runtimeclient.Object{readyWithCA, secretWithCA}, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + argsStr := strings.Join(cmd.Args, " ") + if !strings.Contains(argsStr, "--ssl-mode=VERIFY_CA") { + t.Errorf("expected --ssl-mode=VERIFY_CA in args %v", cmd.Args) + } + if strings.Contains(argsStr, "--ssl-mode=REQUIRED") { + t.Errorf("unexpected --ssl-mode=REQUIRED when CA cert is present, args %v", cmd.Args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + objs := []runtimeclient.Object{ready, notReady, secret} + if len(tc.objects) > 0 { + objs = tc.objects + } + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(objs...), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if tc.wantUpdate && !updateCalled { + t.Error("expected Update to be called for CIDR addition") + } + if !tc.wantErr && tc.check != nil { + tc.check(t, tc.cap.cmd) + } + if tc.wantUpdate { + my := &storage.MySQL{} + if err := apiClient.Get(t.Context(), api.ObjectName(ready), my); err != nil { + t.Fatalf("getting mysql: %v", err) + } + if !cidrsPresent(my.Spec.ForProvider.AllowedCIDRs, cidr) { + t.Errorf("expected CIDR %v to be added, got %v", cidr, my.Spec.ForProvider.AllowedCIDRs) + } + } + }) + } +} diff --git a/exec/mysqldatabase.go b/exec/mysqldatabase.go new file mode 100644 index 00000000..1ac197c8 --- /dev/null +++ b/exec/mysqldatabase.go @@ -0,0 +1,64 @@ +package exec + +import ( + "context" + "fmt" + "os/exec" + + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type mysqlDatabaseCmd struct { + serviceCmd +} + +// Help displays usage examples for the mysqldatabase exec command. +func (cmd mysqlDatabaseCmd) Help() string { + return `Examples: + # Connect to a MySQL database interactively + nctl exec mysqldatabase mydb + + # Import a SQL dump via pipe + cat dump.sql | nctl exec mysqldatabase mydb +` +} + +// Run connects to the named MySQLDatabase resource. +func (cmd *mysqlDatabaseCmd) Run(ctx context.Context, client *api.Client) error { + db := &storage.MySQLDatabase{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), db); err != nil { + return fmt.Errorf("getting mysqldatabase %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, db, mysqlDatabaseConnector{}, cmd.serviceCmd) +} + +// mysqlDatabaseConnector implements cmdExecutor for storage.MySQLDatabase resources. +// It does not implement accessManager because the parent MySQL instance manages CIDRs. +type mysqlDatabaseConnector struct{} + +// Command returns the CLI binary name for connecting to a MySQL database. +func (mysqlDatabaseConnector) Command() string { return mysqlCommand } + +// Endpoint returns the host:port for the TCP connectivity check. +func (mysqlDatabaseConnector) Endpoint(db *storage.MySQLDatabase) string { + if db.Status.AtProvider.FQDN == "" { + return "" + } + return db.Status.AtProvider.FQDN + ":" + mysqlPort +} + +// NewCmd builds the mysql command for connecting to a MySQLDatabase. +func (mysqlDatabaseConnector) NewCmd(ctx context.Context, db *storage.MySQLDatabase, user, pw string) (*exec.Cmd, func(), error) { + dbName := db.Status.AtProvider.Name + if dbName == "" { + dbName = user + } + return newMySQLCmd(ctx, db.Status.AtProvider.FQDN, dbName, db.Status.AtProvider.CACert, user, pw) +} diff --git a/exec/mysqldatabase_test.go b/exec/mysqldatabase_test.go new file mode 100644 index 00000000..893f6acf --- /dev/null +++ b/exec/mysqldatabase_test.go @@ -0,0 +1,104 @@ +package exec + +import ( + "context" + "os/exec" + "strings" + "testing" + + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestMySQLDatabaseCmd(t *testing.T) { + t.Parallel() + + const ( + myDBName = "mydb" + myDBFQDN = "mydb.example.com" + myDBUser = "mydb" + myDBPass = "mydbpass" + ) + + ready := test.MySQLDatabase(myDBName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = myDBFQDN + ready.Status.AtProvider.Name = myDBName + + notReady := test.MySQLDatabase("notready", test.DefaultProject, "nine-es34") + + secret := testSecret(myDBName, test.DefaultProject, myDBUser, myDBPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", nil) + _, notReadyCmd := testDatabaseCmd("notready", nil) + connectCap, connectCmd := testDatabaseCmd(myDBName, nil) + + tests := []struct { + name string + cmd mysqlDatabaseCmd + cap *capturingCmd + wantErr bool + errContains string + check func(t *testing.T, cmd *exec.Cmd) + }{ + { + name: "resource not found", + cmd: mysqlDatabaseCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: mysqlDatabaseCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "connects without cidr management", + cmd: mysqlDatabaseCmd{serviceCmd: connectCmd}, + cap: connectCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + argsStr := strings.Join(cmd.Args, " ") + if !strings.Contains(argsStr, myDBFQDN) { + t.Errorf("expected FQDN %q in args %v", myDBFQDN, cmd.Args) + } + if !strings.Contains(argsStr, "--defaults-extra-file=") { + t.Errorf("expected --defaults-extra-file in args %v", cmd.Args) + } + if strings.Contains(argsStr, myDBPass) { + t.Errorf("password must not appear in args %v", cmd.Args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if updateCalled { + t.Error("Update must not be called for child database resources") + } + if !tc.wantErr && tc.check != nil { + tc.check(t, tc.cap.cmd) + } + }) + } +} diff --git a/exec/postgres.go b/exec/postgres.go new file mode 100644 index 00000000..941f4659 --- /dev/null +++ b/exec/postgres.go @@ -0,0 +1,128 @@ +package exec + +import ( + "context" + "fmt" + "net" + "net/url" + "os/exec" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + postgresPort = "5432" + postgresCommand = "psql" +) + +type postgresCmd struct { + serviceCmd + Database string `name:"database" short:"d" default:"postgres" completion-predictor:"postgres_databases" help:"Database name to connect to."` +} + +// Help displays usage examples for the postgres exec command. +func (cmd postgresCmd) Help() string { + return `Examples: + # Connect to a PostgreSQL instance interactively + nctl exec postgres myinstance + + # Connect to a specific database + nctl exec postgres myinstance -d mydb + + # Import a SQL dump via pipe + cat dump.sql | nctl exec postgres myinstance + + # Pass extra flags to psql (after --) + nctl exec postgres myinstance -- --no-pager +` +} + +func (cmd *postgresCmd) Run(ctx context.Context, client *api.Client) error { + pg := &storage.Postgres{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), pg); err != nil { + return fmt.Errorf("getting postgres %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, pg, postgresConnector{database: cmd.Database}, cmd.serviceCmd) +} + +// postgresConnector implements cmdExecutor for storage.Postgres instances. +type postgresConnector struct { + database string +} + +func (postgresConnector) Command() string { return postgresCommand } + +func (postgresConnector) Endpoint(pg *storage.Postgres) string { + if pg.Status.AtProvider.FQDN == "" { + return "" + } + return net.JoinHostPort(pg.Status.AtProvider.FQDN, postgresPort) +} + +func (postgresConnector) AllowedCIDRs(pg *storage.Postgres) []meta.IPv4CIDR { + return pg.Spec.ForProvider.AllowedCIDRs +} + +func (postgresConnector) Update(ctx context.Context, client *api.Client, pg *storage.Postgres, cidrs []meta.IPv4CIDR) error { + current := &storage.Postgres{} + if err := client.Get(ctx, api.ObjectName(pg), current); err != nil { + return err + } + current.Spec.ForProvider.AllowedCIDRs = cidrs + return client.Update(ctx, current) +} + +// NewCmd builds the psql command with PGPASSWORD passed via env instead of the connection URL. +func (c postgresConnector) NewCmd(ctx context.Context, pg *storage.Postgres, user, pw string) (*exec.Cmd, func(), error) { + return newPsqlCmd(ctx, pg.Status.AtProvider.FQDN, c.database, pg.Status.AtProvider.CACert, user, pw) +} + +// newPsqlCmd returns an exec.Cmd for psql. The password is passed via PGPASSWORD +// rather than the connection URL so it does not appear in the process argument list. +func newPsqlCmd(ctx context.Context, fqdn, dbName, caCertBase64, user, pw string) (*exec.Cmd, func(), error) { + dir, cleanup, err := createTempDir() + if err != nil { + return nil, func() {}, err + } + + caPath, err := writeCACert(dir, caCertBase64) + if err != nil { + cleanup() + return nil, func() {}, err + } + + connURL := postgresConnectionURL(fqdn, user, dbName, caPath) + cmd := exec.CommandContext(ctx, postgresCommand, connURL.String()) + cmd.Env = []string{"PGPASSWORD=" + pw} + return cmd, cleanup, nil +} + +// postgresConnectionURL builds a psql connection URL without a password. +// sslmode is set to verify-ca when a CA cert path is provided, otherwise require. +func postgresConnectionURL(fqdn, user, db, caPath string) *url.URL { + if db == "" { + db = user + } + q := url.Values{} + if caPath != "" { + q.Set("sslmode", "verify-ca") + q.Set("sslrootcert", caPath) + } else { + q.Set("sslmode", "require") + } + return &url.URL{ + Scheme: "postgres", + Host: fqdn, + User: url.User(user), + Path: db, + RawQuery: q.Encode(), + } +} diff --git a/exec/postgres_test.go b/exec/postgres_test.go new file mode 100644 index 00000000..139c224f --- /dev/null +++ b/exec/postgres_test.go @@ -0,0 +1,167 @@ +package exec + +import ( + "context" + "os/exec" + "slices" + "strings" + "testing" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestPostgresCmd(t *testing.T) { + t.Parallel() + + const ( + pgName = "mypg" + location = "nine-es34" + fqdn = "mypg.example.com" + pgUser = "admin" + pgPass = "secret" + ) + + cidr := []meta.IPv4CIDR{"203.0.113.5/32"} + + ready := test.Postgres(pgName, test.DefaultProject, location) + ready.Status.AtProvider.FQDN = fqdn + ready.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + + notReady := test.Postgres("notready", test.DefaultProject, location) + + secret := testSecret(pgName, test.DefaultProject, pgUser, pgPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", &cidr) + _, notReadyCmd := testDatabaseCmd("notready", &cidr) + alreadyCap, alreadyPresentCmd := testDatabaseCmd(pgName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + newCidrCap, newCidrCmd := testDatabaseCmdConfirmed(pgName, &cidr, true) + credsCap, credsCmd := testDatabaseCmd(pgName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + dbCap, dbCmd := testDatabaseCmd(pgName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + + tests := []struct { + name string + cmd postgresCmd + cap *capturingCmd + wantErr bool + errContains string + wantUpdate bool + check func(t *testing.T, cmd *exec.Cmd) + }{ + { + name: "resource not found", + cmd: postgresCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: postgresCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "cidr already present skips update", + cmd: postgresCmd{serviceCmd: alreadyPresentCmd}, + cap: alreadyCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if !strings.Contains(strings.Join(cmd.Args, " "), fqdn) { + t.Errorf("expected FQDN %q in args %v", fqdn, cmd.Args) + } + }, + }, + { + name: "new cidr triggers update", + cmd: postgresCmd{serviceCmd: newCidrCmd}, + cap: newCidrCap, + wantUpdate: true, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if !strings.Contains(strings.Join(cmd.Args, " "), fqdn) { + t.Errorf("expected FQDN %q in args %v", fqdn, cmd.Args) + } + }, + }, + { + name: "credentials passed securely", + cmd: postgresCmd{serviceCmd: credsCmd}, + cap: credsCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + argsStr := strings.Join(cmd.Args, " ") + if strings.Contains(argsStr, pgPass) { + t.Errorf("password must not appear in args %v", cmd.Args) + } + if !strings.Contains(argsStr, pgUser) { + t.Errorf("expected user %q in args %v", pgUser, cmd.Args) + } + if !containsEnv(cmd.Env, "PGPASSWORD="+pgPass) { + t.Errorf("expected PGPASSWORD env var, got %v", cmd.Env) + } + }, + }, + { + name: "custom database appears in connection string", + cmd: postgresCmd{serviceCmd: dbCmd, Database: "mydb"}, + cap: dbCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + if !strings.Contains(strings.Join(cmd.Args, " "), "/mydb") { + t.Errorf("expected database %q in args %v", "mydb", cmd.Args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if tc.wantUpdate && !updateCalled { + t.Error("expected Update to be called for CIDR addition") + } + if !tc.wantUpdate && !tc.wantErr && updateCalled { + t.Error("unexpected Update call when CIDR already present") + } + if !tc.wantErr && tc.check != nil { + tc.check(t, tc.cap.cmd) + } + + if tc.wantUpdate { + pg := &storage.Postgres{} + if err := apiClient.Get(t.Context(), api.ObjectName(ready), pg); err != nil { + t.Fatalf("getting postgres: %v", err) + } + if !cidrsPresent(pg.Spec.ForProvider.AllowedCIDRs, cidr) { + t.Errorf("expected CIDR %v to be added, got %v", cidr, pg.Spec.ForProvider.AllowedCIDRs) + } + } + }) + } +} + +// containsEnv reports whether the KEY=VALUE entry is present in env. +func containsEnv(env []string, entry string) bool { + return slices.Contains(env, entry) +} diff --git a/exec/postgresdatabase.go b/exec/postgresdatabase.go new file mode 100644 index 00000000..b6d12fbf --- /dev/null +++ b/exec/postgresdatabase.go @@ -0,0 +1,61 @@ +package exec + +import ( + "context" + "fmt" + "net" + "os/exec" + + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type postgresDatabaseCmd struct { + serviceCmd +} + +// Help displays usage examples for the postgresdatabase exec command. +func (cmd postgresDatabaseCmd) Help() string { + return `Examples: + # Connect to a PostgreSQL database interactively + nctl exec postgresdatabase mydb + + # Import a SQL dump via pipe + cat dump.sql | nctl exec postgresdatabase mydb +` +} + +// Run connects to the named PostgresDatabase resource. +func (cmd *postgresDatabaseCmd) Run(ctx context.Context, client *api.Client) error { + db := &storage.PostgresDatabase{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), db); err != nil { + return fmt.Errorf("getting postgresdatabase %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, db, postgresDatabaseConnector{}, cmd.serviceCmd) +} + +// postgresDatabaseConnector implements cmdExecutor for storage.PostgresDatabase resources. +// It does not implement accessManager because the parent Postgres instance manages CIDRs. +type postgresDatabaseConnector struct{} + +// Command returns the CLI binary name for connecting to a PostgreSQL database. +func (postgresDatabaseConnector) Command() string { return postgresCommand } + +// Endpoint returns the host:port for the TCP connectivity check. +func (postgresDatabaseConnector) Endpoint(db *storage.PostgresDatabase) string { + if db.Status.AtProvider.FQDN == "" { + return "" + } + return net.JoinHostPort(db.Status.AtProvider.FQDN, postgresPort) +} + +// NewCmd builds the psql command for connecting to a PostgresDatabase. +func (postgresDatabaseConnector) NewCmd(ctx context.Context, db *storage.PostgresDatabase, user, pw string) (*exec.Cmd, func(), error) { + return newPsqlCmd(ctx, db.Status.AtProvider.FQDN, db.Status.AtProvider.Name, db.Status.AtProvider.CACert, user, pw) +} diff --git a/exec/postgresdatabase_test.go b/exec/postgresdatabase_test.go new file mode 100644 index 00000000..353f3d66 --- /dev/null +++ b/exec/postgresdatabase_test.go @@ -0,0 +1,107 @@ +package exec + +import ( + "context" + "os/exec" + "strings" + "testing" + + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestPostgresDatabaseCmd(t *testing.T) { + t.Parallel() + + const ( + dbName = "mydb" + dbFQDN = "mydb.example.com" + dbUser = "mydb" + dbPass = "dbsecret" + ) + + ready := test.PostgresDatabase(dbName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = dbFQDN + ready.Status.AtProvider.Name = dbName + + notReady := test.PostgresDatabase("notready", test.DefaultProject, "nine-es34") + + secret := testSecret(dbName, test.DefaultProject, dbUser, dbPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", nil) + _, notReadyCmd := testDatabaseCmd("notready", nil) + connectCap, connectCmd := testDatabaseCmd(dbName, nil) + + tests := []struct { + name string + cmd postgresDatabaseCmd + cap *capturingCmd + wantErr bool + errContains string + check func(t *testing.T, cmd *exec.Cmd) + }{ + { + name: "resource not found", + cmd: postgresDatabaseCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: postgresDatabaseCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "connects without cidr management", + cmd: postgresDatabaseCmd{serviceCmd: connectCmd}, + cap: connectCap, + check: func(t *testing.T, cmd *exec.Cmd) { + t.Helper() + argsStr := strings.Join(cmd.Args, " ") + if !strings.Contains(argsStr, dbFQDN) { + t.Errorf("expected FQDN %q in args %v", dbFQDN, cmd.Args) + } + if !strings.Contains(argsStr, dbName) { + t.Errorf("expected dbname %q in args %v", dbName, cmd.Args) + } + if strings.Contains(argsStr, dbPass) { + t.Errorf("password must not appear in args %v", cmd.Args) + } + if !containsEnv(cmd.Env, "PGPASSWORD="+dbPass) { + t.Errorf("expected PGPASSWORD env var, got %v", cmd.Env) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if updateCalled { + t.Error("Update must not be called for child database resources") + } + if !tc.wantErr && tc.check != nil { + tc.check(t, tc.cap.cmd) + } + }) + } +} diff --git a/get/apiserviceaccount.go b/get/apiserviceaccount.go index e107c287..7673a0df 100644 --- a/get/apiserviceaccount.go +++ b/get/apiserviceaccount.go @@ -144,7 +144,7 @@ func (cmd *apiServiceAccountsCmd) printSecret( key string, out *output, ) error { - data, err := getConnectionSecret(ctx, client, key, sa) + data, err := connectionSecret(ctx, client, key, sa) if err != nil { return err } diff --git a/get/bucketuser.go b/get/bucketuser.go index 681fbda7..79b49975 100644 --- a/get/bucketuser.go +++ b/get/bucketuser.go @@ -107,7 +107,7 @@ func (cmd *bucketUserCmd) printSecret( key string, out *output, ) error { - data, err := getConnectionSecret(ctx, client, key, user) + data, err := connectionSecret(ctx, client, key, user) if err != nil { return err } diff --git a/get/database.go b/get/database.go index 67d0fd6c..30cb35a6 100644 --- a/get/database.go +++ b/get/database.go @@ -47,7 +47,7 @@ func (cmd *databaseCmd) run(ctx context.Context, client *api.Client, get *Cmd, } if cmd.Name != "" && cmd.PrintConnectionString { - secrets, err := getConnectionSecretMap(ctx, client, databaseResources.GetItems()[0]) + secrets, err := ConnectionSecretMap(ctx, client, databaseResources.GetItems()[0]) if err != nil { return err } @@ -66,7 +66,7 @@ func (cmd *databaseCmd) run(ctx context.Context, client *api.Client, get *Cmd, if err != nil { return err } - return printBase64(&get.Writer, ca) + return WriteBase64(&get.Writer, ca) } switch get.Format { diff --git a/get/get.go b/get/get.go index e14159a7..8e17a694 100644 --- a/get/get.go +++ b/get/get.go @@ -180,7 +180,7 @@ func (out *output) notFound(kind, project string) error { return err } -func getConnectionSecretMap(ctx context.Context, client *api.Client, mg resource.Managed) (map[string][]byte, error) { +func ConnectionSecretMap(ctx context.Context, client *api.Client, mg resource.Managed) (map[string][]byte, error) { secret, err := client.GetConnectionSecret(ctx, mg) if err != nil { return nil, err @@ -189,8 +189,8 @@ func getConnectionSecretMap(ctx context.Context, client *api.Client, mg resource return secret.Data, nil } -func getConnectionSecret(ctx context.Context, client *api.Client, key string, mg resource.Managed) (string, error) { - secrets, err := getConnectionSecretMap(ctx, client, mg) +func connectionSecret(ctx context.Context, client *api.Client, key string, mg resource.Managed) (string, error) { + secrets, err := ConnectionSecretMap(ctx, client, mg) if err != nil { return "", fmt.Errorf("unable to get connection secret: %w", err) } @@ -210,7 +210,7 @@ func (cmd *resourceCmd) printSecret( out *output, field func(string, string) string, ) error { - secrets, err := getConnectionSecretMap(ctx, client, mg) + secrets, err := ConnectionSecretMap(ctx, client, mg) if err != nil { return err } @@ -229,7 +229,7 @@ func (cmd *resourceCmd) printCredentials( out *output, filter func(key string) bool, ) error { - data, err := getConnectionSecretMap(ctx, client, mg) + data, err := ConnectionSecretMap(ctx, client, mg) if err != nil { return err } @@ -266,7 +266,7 @@ func (cmd *resourceCmd) printCredentials( return nil } -func printBase64(out io.Writer, s string) error { +func WriteBase64(out io.Writer, s string) error { s = strings.TrimSpace(s) if s == "" { return nil diff --git a/get/keyvaluestore.go b/get/keyvaluestore.go index a541a154..59bd1f22 100644 --- a/get/keyvaluestore.go +++ b/get/keyvaluestore.go @@ -37,7 +37,7 @@ func (cmd *keyValueStoreCmd) print(ctx context.Context, client *api.Client, list return cmd.printSecret(ctx, client, &keyValueStoreList.Items[0], out, func(_, pw string) string { return pw }) } if cmd.Name != "" && cmd.PrintCACert { - return printBase64(&out.Writer, keyValueStoreList.Items[0].Status.AtProvider.CACert) + return WriteBase64(&out.Writer, keyValueStoreList.Items[0].Status.AtProvider.CACert) } switch out.Format { diff --git a/get/opensearch.go b/get/opensearch.go index 273a3aa1..1faa6772 100644 --- a/get/opensearch.go +++ b/get/opensearch.go @@ -63,7 +63,7 @@ func (cmd *openSearchCmd) print( } if cmd.Name != "" && cmd.PrintCACert { - return printBase64(&out.Writer, openSearchList.Items[0].Status.AtProvider.CACert) + return WriteBase64(&out.Writer, openSearchList.Items[0].Status.AtProvider.CACert) } if cmd.Name != "" && cmd.PrintSnapshotBucket { diff --git a/get/postgres.go b/get/postgres.go index 7ebe0885..bea1dbab 100644 --- a/get/postgres.go +++ b/get/postgres.go @@ -65,7 +65,7 @@ func (cmd *postgresCmd) connectionString(mg resource.Managed, secrets map[string } for user, pw := range secrets { - return postgresConnectionString(my.Status.AtProvider.FQDN, user, "postgres", pw), nil + return PostgresConnectionString(my.Status.AtProvider.FQDN, user, "postgres", pw).String(), nil } return "", nil diff --git a/get/postgresdatabase.go b/get/postgresdatabase.go index 04c22099..cd0dbc9e 100644 --- a/get/postgresdatabase.go +++ b/get/postgresdatabase.go @@ -66,21 +66,25 @@ func (cmd *postgresDatabaseCmd) connectionString(mg resource.Managed, secrets ma } for user, pw := range secrets { - return postgresConnectionString(my.Status.AtProvider.FQDN, user, user, pw), nil + return PostgresConnectionString(my.Status.AtProvider.FQDN, user, user, pw).String(), nil } return "", nil } -// postgresConnectionString according to the PostgreSQL documentation: +// PostgresConnectionString according to the PostgreSQL documentation: // https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING -func postgresConnectionString(fqdn, user, db string, pw []byte) string { +func PostgresConnectionString(fqdn, user, db string, pw []byte) *url.URL { + q := url.Values{} + q.Set("sslmode", "require") + u := &url.URL{ - Scheme: "postgres", - Host: fqdn, - User: url.UserPassword(user, string(pw)), - Path: db, + Scheme: "postgres", + Host: fqdn, + User: url.UserPassword(user, string(pw)), + Path: db, + RawQuery: q.Encode(), } - return u.String() + return u } diff --git a/go.mod b/go.mod index ef13efdb..52095151 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/jotaen/kong-completion v0.0.11 github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de github.com/lucasepe/codename v0.2.1-0.20230220151621-5e31bf1e775f - github.com/mattn/go-isatty v0.0.22 + github.com/mattn/go-isatty v0.0.20 github.com/moby/moby v28.5.2+incompatible github.com/moby/term v0.5.2 github.com/ninech/apis v0.0.0-20260420170138-f082e6318aed @@ -34,6 +34,7 @@ require ( golang.org/x/crypto v0.49.0 golang.org/x/oauth2 v0.36.0 golang.org/x/sync v0.20.0 + golang.org/x/term v0.42.0 k8s.io/api v0.36.0 k8s.io/apimachinery v0.36.0 k8s.io/client-go v0.36.0 @@ -327,7 +328,6 @@ require ( golang.org/x/mod v0.34.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sys v0.43.0 // indirect - golang.org/x/term v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.43.0 // indirect diff --git a/go.sum b/go.sum index d9ce7136..cb9b0405 100644 --- a/go.sum +++ b/go.sum @@ -530,8 +530,8 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= -github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= @@ -972,6 +972,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= diff --git a/internal/format/print.go b/internal/format/print.go index baf5abe4..4413d9dc 100644 --- a/internal/format/print.go +++ b/internal/format/print.go @@ -14,7 +14,8 @@ import ( "github.com/fatih/color" "github.com/goccy/go-yaml/lexer" "github.com/goccy/go-yaml/printer" - "github.com/mattn/go-isatty" + "golang.org/x/term" + "github.com/theckman/yacspin" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -282,7 +283,7 @@ func IsInteractiveEnvironment(out io.Writer) bool { if !isFile { return false } - return isatty.IsTerminal(f.Fd()) || isatty.IsCygwinTerminal(f.Fd()) + return term.IsTerminal(int(f.Fd())) } // stripObj removes some fields which simply add clutter to the yaml output. diff --git a/internal/ipcheck/client.go b/internal/ipcheck/client.go new file mode 100644 index 00000000..ea352ed0 --- /dev/null +++ b/internal/ipcheck/client.go @@ -0,0 +1,152 @@ +// Package ipcheck provides a client for detecting the caller's public IP address. +package ipcheck + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + "net/url" + "sync" + "time" +) + +// ErrStatus represents a non success status code error. +var ErrStatus = errors.New("status code error") + +const ( + // defaultTimeout is the default HTTP timeout. + defaultTimeout = 5 * time.Second + // defaultURL is the default endpoint to query. + defaultURL = "https://ip.nine.ch/" +) + +// defaultClient returns the default Client instance. +var defaultClient = sync.OnceValue(func() *Client { + return New() +}) + +// PublicIP returns the caller's public IP address as reported by the endpoint. +func PublicIP(ctx context.Context) (*Response, error) { + return defaultClient().PublicIP(ctx) +} + +// Client fetches the caller's public IP address from Nine's IP check endpoint. +type Client struct { + // httpClient is the HTTP client to use. If nil, a default client with a 5s timeout is used. + httpClient *http.Client + // userAgent is the value to set in the User-Agent header. + userAgent string + // url is the endpoint to query. Defaults to https://ip-ban-check.nine.ch/. + url *url.URL +} + +// Response is the JSON response from the IP check endpoint. +type Response struct { + Blocked bool `json:"blocked"` + RemoteAddr netip.Addr `json:"remoteAddr"` +} + +// Option is a function that configures a Client. +type Option func(*Client) + +// WithHTTPClient configures the HTTP client to use. +func WithHTTPClient(client *http.Client) Option { + return func(c *Client) { + c.httpClient = client + } +} + +// WithUserAgent configures the User-Agent header to use. +func WithUserAgent(userAgent string) Option { + return func(c *Client) { + c.userAgent = userAgent + } +} + +// WithURL configures the endpoint URL to query. +func WithURL(url *url.URL) Option { + return func(c *Client) { + c.url = url + } +} + +// New creates a new Client with the given options. +func New(options ...Option) *Client { + u, _ := url.Parse(defaultURL) + c := &Client{ + url: u, + httpClient: &http.Client{Timeout: defaultTimeout}, + } + + for _, opt := range options { + opt(c) + } + + return c +} + +// PublicIP returns the caller's public IP address as reported by the endpoint. +func (c *Client) PublicIP(ctx context.Context) (*Response, error) { + req, err := c.newRequest(ctx, http.MethodGet) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + result := Response{} + if _, err := c.doJSON(req, &result); err != nil { + return nil, fmt.Errorf("decoding IP check response: %w", err) + } + + return &result, nil +} + +// newRequest creates a new HTTP request with the given method and URL. +func (c *Client) newRequest(ctx context.Context, method string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, c.url.String(), nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) + } + return req, nil +} + +// doJSON sends the given request and decodes the response into v. +func (c *Client) doJSON(req *http.Request, v any) (*http.Response, error) { + resp, err := c.do(req) + if err != nil { + return resp, err + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + err = json.NewDecoder(resp.Body).Decode(&v) + + return resp, err +} + +// do sends the given request and returns the response. +func (c *Client) do(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return resp, fmt.Errorf( + "%s: %d, %w", + http.StatusText(resp.StatusCode), + resp.StatusCode, + ErrStatus, + ) + } + + return resp, err +} diff --git a/internal/ipcheck/client_test.go b/internal/ipcheck/client_test.go new file mode 100644 index 00000000..aeac645a --- /dev/null +++ b/internal/ipcheck/client_test.go @@ -0,0 +1,72 @@ +package ipcheck_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "testing" + + "github.com/ninech/nctl/internal/ipcheck" +) + +func TestClient_PublicIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response ipcheck.Response + statusCode int + wantIP netip.Addr + wantErr bool + }{ + { + name: "returns remote addr", + response: ipcheck.Response{Blocked: false, RemoteAddr: netip.MustParseAddr("203.0.113.1")}, + statusCode: http.StatusOK, + wantIP: netip.MustParseAddr("203.0.113.1"), + }, + { + name: "server error", + statusCode: http.StatusInternalServerError, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Accept") != "application/json" { + t.Errorf("expected Accept: application/json, got %q", r.Header.Get("Accept")) + } + w.WriteHeader(tc.statusCode) + if tc.statusCode == http.StatusOK { + _ = json.NewEncoder(w).Encode(tc.response) + } + })) + defer srv.Close() + + srvURL, err := url.Parse(srv.URL) + if err != nil { + t.Fatalf("parsing URL %q: %v", srv.URL, err) + } + + c := ipcheck.New( + ipcheck.WithURL(srvURL), + ipcheck.WithHTTPClient(srv.Client()), + ipcheck.WithUserAgent("nctl-test"), + ) + + got, err := c.PublicIP(t.Context()) + if (err != nil) != tc.wantErr { + t.Fatalf("PublicIP() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr && got.RemoteAddr.Compare(tc.wantIP) != 0 { + t.Errorf("PublicIP() = %q, want %q", got.RemoteAddr.String(), tc.wantIP) + } + }) + } +} diff --git a/main.go b/main.go index 30dfe18b..9b2b3da3 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "github.com/alecthomas/kong" completion "github.com/jotaen/kong-completion" management "github.com/ninech/apis/management/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" "github.com/ninech/nctl/api" "github.com/ninech/nctl/apply" "github.com/ninech/nctl/auth" @@ -206,8 +207,7 @@ func main() { } } - var cliErr *cli.Error - if errors.As(err, &cliErr) { + if cliErr, ok := errors.AsType[*cli.Error](err); ok { fmt.Fprintln(writer, err.Error()) kongCtx.Exit(cliErr.ExitCode()) return @@ -223,6 +223,8 @@ func clientPredictors(ctx context.Context, apiClientRequired bool) []completion. nothing := []completion.Option{ completion.WithPredictor("resource_name", complete.PredictNothing), completion.WithPredictor("project_name", complete.PredictNothing), + completion.WithPredictor("postgres_databases", complete.PredictNothing), + completion.WithPredictor("mysql_databases", complete.PredictNothing), } if !apiClientRequired { @@ -239,6 +241,8 @@ func clientPredictors(ctx context.Context, apiClientRequired bool) []completion. completion.WithPredictor("project_name", predictor.NewResourceNameWithKind(client, management.SchemeGroupVersion.WithKind(reflect.TypeFor[management.ProjectList]().Name())), ), + completion.WithPredictor("postgres_databases", predictor.NewInstanceDatabases(client, storage.PostgresGroupVersionKind)), + completion.WithPredictor("mysql_databases", predictor.NewInstanceDatabases(client, storage.MySQLGroupVersionKind)), } } diff --git a/predictor/predictor.go b/predictor/predictor.go index e820c64b..f7381f30 100644 --- a/predictor/predictor.go +++ b/predictor/predictor.go @@ -14,6 +14,7 @@ import ( "github.com/posener/complete" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -26,6 +27,7 @@ const ( // completion. var argResourceMap = map[string]string{ "clusters": "kubernetesclusters", + "kvs": "keyvaluestore", } // Resource is a predictor that completes resource names by querying the API. @@ -174,6 +176,79 @@ func findProjectInSlice(args []string) string { return "" } +// InstanceDatabases is a predictor that completes database names for a +// dedicated database instance (e.g. Postgres, MySQL) by reading the Databases +// map from its status. +type InstanceDatabases struct { + client *api.Client + gvk schema.GroupVersionKind +} + +// NewInstanceDatabases returns a predictor that completes database names for a +// dedicated instance resource. It fetches the instance named by the first +// positional argument on the command line and returns the keys of its +// status.atProvider.databases map. +func NewInstanceDatabases(client *api.Client, gvk schema.GroupVersionKind) complete.Predictor { + return &InstanceDatabases{client: client, gvk: gvk} +} + +// Predict returns the database names available on the instance whose name +// appears as the first positional argument in the completion context. +func (d *InstanceDatabases) Predict(args complete.Args) []string { + name := firstPositionalArg(args.Completed) + if name == "" { + return nil + } + + p, incomplete := findProject(args) + if incomplete { + return nil + } + + ns := d.client.Project + if p != "" { + ns = p + } + + u := &unstructured.Unstructured{} + u.SetGroupVersionKind(d.gvk) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + if err := d.client.Get(ctx, types.NamespacedName{Name: name, Namespace: ns}, u); err != nil { + return nil + } + + databases, found, err := unstructured.NestedMap(u.Object, "status", "atProvider", "databases") + if err != nil || !found { + return nil + } + + names := make([]string, 0, len(databases)) + for dbName := range databases { + names = append(names, dbName) + } + return names +} + +// firstPositionalArg returns the first positional (non-flag) argument in args, +// skipping flag-value pairs. It assumes all flags that are not of the form +// --flag=value take a separate value token. +func firstPositionalArg(args []string) string { + for i := 0; i < len(args); i++ { + arg := args[i] + if strings.HasPrefix(arg, "-") { + if !strings.Contains(arg, "=") { + i++ // skip the following value token + } + continue + } + return arg + } + return "" +} + // NewClient creates an API client configured for shell completion. It uses a // static token since dynamic exec config breaks with some shells during // completion. diff --git a/predictor/predictor_test.go b/predictor/predictor_test.go index bdf714e1..a083f447 100644 --- a/predictor/predictor_test.go +++ b/predictor/predictor_test.go @@ -2,9 +2,12 @@ package predictor import ( "bytes" + "sort" "strconv" "testing" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/internal/test" "github.com/posener/complete" ) @@ -217,3 +220,124 @@ func TestFindProjectInSlice(t *testing.T) { }) } } + +func TestFirstPositionalArg(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + want string + }{ + { + name: "empty", + args: []string{}, + want: "", + }, + { + name: "positional only", + args: []string{"myinstance"}, + want: "myinstance", + }, + { + name: "flag before positional", + args: []string{"-p", "myproject", "myinstance"}, + want: "myinstance", + }, + { + name: "positional then flag", + args: []string{"myinstance", "--database"}, + want: "myinstance", + }, + { + name: "flag equals form skips no value token", + args: []string{"--project=myproject", "myinstance"}, + want: "myinstance", + }, + { + name: "only flags", + args: []string{"-p", "myproject"}, + want: "", + }, + { + name: "dangling flag without value", + args: []string{"--database"}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := firstPositionalArg(tt.args); got != tt.want { + t.Errorf("firstPositionalArg(%v) = %q, want %q", tt.args, got, tt.want) + } + }) + } +} + +func TestInstanceDatabasesPredict(t *testing.T) { + t.Parallel() + + const ( + instanceName = "mypg" + project = test.DefaultProject + location = "nine-es34" + ) + + pg := test.Postgres(instanceName, project, location) + pg.Status.AtProvider.Databases = map[string]storage.DatabaseObservation{ + "appdb": {}, + "otherdb": {}, + "postgres": {}, + } + + client := test.SetupClient(t, + test.WithObjects(pg), + test.WithDefaultProject(project), + ) + + predictor := NewInstanceDatabases(client, storage.PostgresGroupVersionKind) + + tests := []struct { + name string + completed []string + want []string + }{ + { + name: "returns databases for named instance", + completed: []string{instanceName, "--database"}, + want: []string{"appdb", "otherdb", "postgres"}, + }, + { + name: "returns nil when no instance name provided", + completed: []string{"--database"}, + want: nil, + }, + { + name: "returns nil for unknown instance", + completed: []string{"doesnotexist", "--database"}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := predictor.Predict(complete.Args{Completed: tt.completed}) + sort.Strings(got) + sort.Strings(tt.want) + + if len(got) != len(tt.want) { + t.Fatalf("Predict() = %v, want %v", got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("Predict()[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +}