diff --git a/experimental/ssh/internal/keys/keys.go b/experimental/ssh/internal/keys/keys.go index 5c835b279f..1fafdef16f 100644 --- a/experimental/ssh/internal/keys/keys.go +++ b/experimental/ssh/internal/keys/keys.go @@ -6,9 +6,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "errors" "fmt" - "io/fs" "os" "path/filepath" @@ -52,23 +50,15 @@ func generateSSHKeyPair() ([]byte, []byte, error) { } func SaveSSHKeyPair(keyPath string, privateKeyBytes, publicKeyBytes []byte) error { - err := os.RemoveAll(filepath.Dir(keyPath)) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("failed to remove existing key directory: %w", err) - } - if err := os.MkdirAll(filepath.Dir(keyPath), 0o700); err != nil { return fmt.Errorf("failed to create directory for key: %w", err) } - if err := os.WriteFile(keyPath, privateKeyBytes, 0o600); err != nil { return fmt.Errorf("failed to write private key to file: %w", err) } - if err := os.WriteFile(keyPath+".pub", publicKeyBytes, 0o644); err != nil { return fmt.Errorf("failed to write public key to file: %w", err) } - return nil } diff --git a/experimental/ssh/internal/keys/keys_test.go b/experimental/ssh/internal/keys/keys_test.go new file mode 100644 index 0000000000..68054311f8 --- /dev/null +++ b/experimental/ssh/internal/keys/keys_test.go @@ -0,0 +1,89 @@ +package keys_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/experimental/ssh/internal/keys" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSaveSSHKeyPairNewFiles(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "session1") + privateKey := []byte("private-key-content") + publicKey := []byte("public-key-content") + + err := keys.SaveSSHKeyPair(keyPath, privateKey, publicKey) + require.NoError(t, err) + + gotPrivate, err := os.ReadFile(keyPath) + require.NoError(t, err) + assert.Equal(t, privateKey, gotPrivate) + + gotPublic, err := os.ReadFile(keyPath + ".pub") + require.NoError(t, err) + assert.Equal(t, publicKey, gotPublic) +} + +func TestSaveSSHKeyPairOverwritesExistingFiles(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "session1") + + // Write initial keys. + require.NoError(t, keys.SaveSSHKeyPair(keyPath, []byte("old-private"), []byte("old-public"))) + + // Overwrite with new keys. + newPrivate := []byte("new-private-key-content") + newPublic := []byte("new-public-key-content") + err := keys.SaveSSHKeyPair(keyPath, newPrivate, newPublic) + require.NoError(t, err) + + gotPrivate, err := os.ReadFile(keyPath) + require.NoError(t, err) + assert.Equal(t, newPrivate, gotPrivate) + + gotPublic, err := os.ReadFile(keyPath + ".pub") + require.NoError(t, err) + assert.Equal(t, newPublic, gotPublic) +} + +func TestSaveSSHKeyPairCreatesDirectory(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "nonexistent-subdir", "session1") + privateKey := []byte("private-key-content") + publicKey := []byte("public-key-content") + + err := keys.SaveSSHKeyPair(keyPath, privateKey, publicKey) + require.NoError(t, err) + + gotPrivate, err := os.ReadFile(keyPath) + require.NoError(t, err) + assert.Equal(t, privateKey, gotPrivate) + + gotPublic, err := os.ReadFile(keyPath + ".pub") + require.NoError(t, err) + assert.Equal(t, publicKey, gotPublic) +} + +func TestSaveSSHKeyPairDoesNotAffectOtherSessions(t *testing.T) { + dir := t.TempDir() + keyPath1 := filepath.Join(dir, "session1") + keyPath2 := filepath.Join(dir, "session2") + + require.NoError(t, keys.SaveSSHKeyPair(keyPath1, []byte("private-1"), []byte("public-1"))) + require.NoError(t, keys.SaveSSHKeyPair(keyPath2, []byte("private-2"), []byte("public-2"))) + + // Overwrite session1 — session2 must be untouched. + require.NoError(t, keys.SaveSSHKeyPair(keyPath1, []byte("private-1-new"), []byte("public-1-new"))) + + gotPrivate2, err := os.ReadFile(keyPath2) + require.NoError(t, err) + assert.Equal(t, []byte("private-2"), gotPrivate2) + + gotPublic2, err := os.ReadFile(keyPath2 + ".pub") + require.NoError(t, err) + assert.Equal(t, []byte("public-2"), gotPublic2) +}