diff --git a/VERSION b/VERSION index 3d22ace4..a822eaa9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.0.40 +2.0.45 diff --git a/buildinfo/version.go b/buildinfo/version.go index 9618ef22..d1ebf27d 100644 --- a/buildinfo/version.go +++ b/buildinfo/version.go @@ -9,8 +9,9 @@ var buildInfo *BuildInfo // App identity variables. Defaults are prod values; overridden in main.go for dev builds. var ( - AppName = "deepsource" // binary name / display name - ConfigDirName = ".deepsource" // ~// + AppName = "deepsource" // binary name / display name + ConfigDirName = ".deepsource" // ~// + BaseURL = "https://cli.deepsource.com" // CDN base for manifest and archives ) // BuildInfo describes the compile time information. diff --git a/cmd/deepsource/main.go b/cmd/deepsource/main.go index 71fd81b4..00be64b5 100644 --- a/cmd/deepsource/main.go +++ b/cmd/deepsource/main.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "net/http" "os" "strings" "time" @@ -11,7 +12,9 @@ import ( v "github.com/deepsourcelabs/cli/buildinfo" "github.com/deepsourcelabs/cli/command" "github.com/deepsourcelabs/cli/internal/cli/style" + "github.com/deepsourcelabs/cli/internal/debug" clierrors "github.com/deepsourcelabs/cli/internal/errors" + "github.com/deepsourcelabs/cli/internal/update" "github.com/getsentry/sentry-go" ) @@ -40,6 +43,7 @@ func mainRun() (exitCode int) { if buildMode == "dev" { v.AppName = "deepsource-dev" v.ConfigDirName = ".deepsource-dev" + v.BaseURL = "https://cli.deepsource.one" } // Init sentry @@ -67,6 +71,32 @@ func mainRun() (exitCode int) { func run() int { v.SetBuildInfo(version, Date, buildMode) + // Two-phase auto-update: apply pending update or check for new one + if update.ShouldAutoUpdate() { + state, err := update.ReadUpdateState() + if err != nil { + debug.Log("update: %v", err) + } + + if state != nil { + // Phase 2: a previous run found a newer version — apply it now + client := &http.Client{Timeout: 30 * time.Second} + newVer, err := update.ApplyUpdate(client) + if err != nil { + debug.Log("update: %v", err) + } else if newVer != "" { + fmt.Fprintf(os.Stderr, "%s\n", style.Yellow("Updated DeepSource CLI to v%s", newVer)) + } + } else { + // Phase 1: check manifest and write state file for next run + client := &http.Client{Timeout: 3 * time.Second} + if err := update.CheckForUpdate(client); err != nil { + debug.Log("update: %v", err) + } + } + } + + exitCode := 0 if err := command.Execute(); err != nil { var cliErr *clierrors.CLIError if errors.As(err, &cliErr) { @@ -78,7 +108,8 @@ func run() int { sentry.CaptureException(err) } sentry.Flush(2 * time.Second) - return 1 + exitCode = 1 } - return 0 + + return exitCode } diff --git a/config/config.go b/config/config.go index 27195726..f5bc37ca 100644 --- a/config/config.go +++ b/config/config.go @@ -16,6 +16,7 @@ type CLIConfig struct { User string `toml:"user"` Token string `toml:"token"` TokenExpiresIn time.Time `toml:"token_expires_in,omitempty"` + AutoUpdate *bool `toml:"auto_update,omitempty"` TokenFromEnv bool `toml:"-"` } diff --git a/internal/update/manifest.go b/internal/update/manifest.go new file mode 100644 index 00000000..5bceef0a --- /dev/null +++ b/internal/update/manifest.go @@ -0,0 +1,53 @@ +package update + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "runtime" + + "github.com/deepsourcelabs/cli/buildinfo" +) + +// Manifest represents the CLI release manifest served by the CDN. +type Manifest struct { + Version string `json:"version"` + BuildTime string `json:"buildTime"` + Platforms map[string]PlatformInfo `json:"platforms"` +} + +// PlatformInfo holds the archive filename and checksum for a platform. +type PlatformInfo struct { + Archive string `json:"archive"` + SHA256 string `json:"sha256"` +} + +// FetchManifest downloads and parses the release manifest. +func FetchManifest(client *http.Client) (*Manifest, error) { + resp, err := client.Get(buildinfo.BaseURL + "/manifest.json") + if err != nil { + return nil, fmt.Errorf("fetching manifest: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("manifest returned HTTP %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading manifest body: %w", err) + } + + var m Manifest + if err := json.Unmarshal(body, &m); err != nil { + return nil, fmt.Errorf("parsing manifest JSON: %w", err) + } + return &m, nil +} + +// PlatformKey returns the manifest key for the current OS/arch. +func PlatformKey() string { + return runtime.GOOS + "_" + runtime.GOARCH +} diff --git a/internal/update/updater.go b/internal/update/updater.go new file mode 100644 index 00000000..e23297e4 --- /dev/null +++ b/internal/update/updater.go @@ -0,0 +1,333 @@ +package update + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/deepsourcelabs/cli/buildinfo" + "github.com/deepsourcelabs/cli/config" + "github.com/deepsourcelabs/cli/internal/debug" +) + +// UpdateState is the on-disk state written by CheckForUpdate and consumed by ApplyUpdate. +type UpdateState struct { + Version string `json:"version"` + ArchiveURL string `json:"archive_url"` + SHA256 string `json:"sha256"` + CheckedAt time.Time `json:"checked_at"` +} + +// updateStatePath returns the path to the update state file (~/.deepsource/update.json). +func updateStatePath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, buildinfo.ConfigDirName, "update.json") +} + +// ReadUpdateState reads the update state file. Returns nil if the file does not exist. +func ReadUpdateState() (*UpdateState, error) { + data, err := os.ReadFile(updateStatePath()) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("reading update state: %w", err) + } + var s UpdateState + if err := json.Unmarshal(data, &s); err != nil { + return nil, fmt.Errorf("parsing update state: %w", err) + } + return &s, nil +} + +func writeUpdateState(s *UpdateState) error { + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Errorf("marshaling update state: %w", err) + } + p := updateStatePath() + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + return fmt.Errorf("creating config dir: %w", err) + } + if err := os.WriteFile(p, data, 0o644); err != nil { + return fmt.Errorf("writing update state: %w", err) + } + return nil +} + +func clearUpdateState() { + _ = os.Remove(updateStatePath()) +} + +// CheckForUpdate fetches the manifest, compares versions, and writes a state +// file if a newer version is available. This is meant to be fast (~100-200ms). +func CheckForUpdate(client *http.Client) error { + bi := buildinfo.GetBuildInfo() + if bi == nil { + return fmt.Errorf("build info not set") + } + + manifest, err := FetchManifest(client) + if err != nil { + return err + } + + newer, err := IsNewer(bi.Version, manifest.Version) + if err != nil { + return err + } + if !newer { + debug.Log("update: already up to date (current=%s, remote=%s)", bi.Version, manifest.Version) + return nil + } + + key := PlatformKey() + platform, ok := manifest.Platforms[key] + if !ok { + return fmt.Errorf("no release for platform %s", key) + } + + state := &UpdateState{ + Version: manifest.Version, + ArchiveURL: buildinfo.BaseURL + "/" + platform.Archive, + SHA256: platform.SHA256, + CheckedAt: time.Now().UTC(), + } + + debug.Log("update: newer version %s available, writing state file", manifest.Version) + return writeUpdateState(state) +} + +// ApplyUpdate reads the state file, downloads the archive, verifies, extracts, +// and replaces the binary. Returns the new version string on success. +// Clears the state file regardless of outcome so we don't retry broken updates forever. +func ApplyUpdate(client *http.Client) (string, error) { + state, err := ReadUpdateState() + if err != nil { + clearUpdateState() + return "", err + } + if state == nil { + return "", nil + } + + // Clear state file up front so a failed update doesn't retry forever. + // The next run will do a fresh CheckForUpdate instead. + clearUpdateState() + + debug.Log("update: applying update to v%s", state.Version) + + data, err := downloadFile(client, state.ArchiveURL) + if err != nil { + return "", err + } + + if err := verifyChecksum(data, state.SHA256); err != nil { + return "", err + } + + binaryName := buildinfo.AppName + if runtime.GOOS == "windows" { + binaryName += ".exe" + } + + var binaryData []byte + if strings.HasSuffix(state.ArchiveURL, ".zip") { + binaryData, err = extractFromZip(data, binaryName) + } else { + binaryData, err = extractFromTarGz(data, binaryName) + } + if err != nil { + return "", err + } + + if err := replaceBinary(binaryData); err != nil { + return "", err + } + + debug.Log("update: updated to v%s", state.Version) + return state.Version, nil +} + +// ShouldAutoUpdate reports whether the auto-updater should run. +func ShouldAutoUpdate() bool { + bi := buildinfo.GetBuildInfo() + if bi == nil { + return false + } + + // Skip local dev builds (go run / go build without ldflags) + if bi.Version == "development" { + debug.Log("update: skipping (local dev build)") + return false + } + + // Skip in CI environments + ciVars := []string{ + "CI", "GITHUB_ACTIONS", "GITLAB_CI", "CIRCLECI", + "TRAVIS", "JENKINS_URL", "BUILDKITE", "TF_BUILD", + } + for _, v := range ciVars { + if os.Getenv(v) != "" { + debug.Log("update: skipping (CI detected via %s)", v) + return false + } + } + + // Check config + cfg, err := config.GetConfig() + if err == nil && cfg.AutoUpdate != nil && !*cfg.AutoUpdate { + debug.Log("update: skipping (disabled in config)") + return false + } + + return true +} + +func downloadFile(client *http.Client, url string) ([]byte, error) { + resp, err := client.Get(url) + if err != nil { + return nil, fmt.Errorf("downloading %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download returned HTTP %d", resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading download body: %w", err) + } + return data, nil +} + +func verifyChecksum(data []byte, expected string) error { + h := sha256.Sum256(data) + actual := hex.EncodeToString(h[:]) + if actual != expected { + return fmt.Errorf("checksum mismatch: got %s, want %s", actual, expected) + } + debug.Log("update: checksum verified") + return nil +} + +func extractFromTarGz(data []byte, binaryName string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("opening gzip: %w", err) + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("reading tar: %w", err) + } + if filepath.Base(hdr.Name) == binaryName && hdr.Typeflag == tar.TypeReg { + content, err := io.ReadAll(tr) + if err != nil { + return nil, fmt.Errorf("reading %s from tar: %w", binaryName, err) + } + return content, nil + } + } + return nil, fmt.Errorf("%s not found in archive", binaryName) +} + +func extractFromZip(data []byte, binaryName string) ([]byte, error) { + r, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return nil, fmt.Errorf("opening zip: %w", err) + } + for _, f := range r.File { + if filepath.Base(f.Name) == binaryName { + rc, err := f.Open() + if err != nil { + return nil, fmt.Errorf("opening %s in zip: %w", binaryName, err) + } + defer rc.Close() + content, err := io.ReadAll(rc) + if err != nil { + return nil, fmt.Errorf("reading %s from zip: %w", binaryName, err) + } + return content, nil + } + } + return nil, fmt.Errorf("%s not found in archive", binaryName) +} + +// replaceBinary atomically replaces the current executable with newBinary. +func replaceBinary(newBinary []byte) error { + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("finding current executable: %w", err) + } + exe, err = filepath.EvalSymlinks(exe) + if err != nil { + return fmt.Errorf("resolving symlinks: %w", err) + } + + dir := filepath.Dir(exe) + base := filepath.Base(exe) + + // Write new binary to a temp file in the same directory (same filesystem for rename) + tmp, err := os.CreateTemp(dir, base+".new.*") + if err != nil { + return fmt.Errorf("creating temp file: %w", err) + } + tmpPath := tmp.Name() + + if _, err := tmp.Write(newBinary); err != nil { + tmp.Close() + os.Remove(tmpPath) + return fmt.Errorf("writing new binary: %w", err) + } + if err := tmp.Chmod(0o755); err != nil { + tmp.Close() + os.Remove(tmpPath) + return fmt.Errorf("setting permissions: %w", err) + } + if err := tmp.Close(); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("closing temp file: %w", err) + } + + // Rename current binary to .bak, then new to current + bakPath := exe + ".bak" + _ = os.Remove(bakPath) // clean up any leftover .bak + + if err := os.Rename(exe, bakPath); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("backing up current binary: %w", err) + } + + if err := os.Rename(tmpPath, exe); err != nil { + // Try to restore the backup + _ = os.Rename(bakPath, exe) + os.Remove(tmpPath) + return fmt.Errorf("replacing binary: %w", err) + } + + // Clean up backup + _ = os.Remove(bakPath) + + return nil +} diff --git a/internal/update/updater_test.go b/internal/update/updater_test.go new file mode 100644 index 00000000..ea09360c --- /dev/null +++ b/internal/update/updater_test.go @@ -0,0 +1,408 @@ +package update + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/deepsourcelabs/cli/buildinfo" +) + +func TestVerifyChecksum(t *testing.T) { + data := []byte("hello world") + h := sha256.Sum256(data) + good := hex.EncodeToString(h[:]) + + if err := verifyChecksum(data, good); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if err := verifyChecksum(data, "0000000000000000000000000000000000000000000000000000000000000000"); err == nil { + t.Fatal("expected checksum mismatch error") + } +} + +func TestExtractFromTarGz(t *testing.T) { + content := []byte("#!/bin/sh\necho hi\n") + archive := createTarGz(t, "deepsource", content) + + got, err := extractFromTarGz(archive, "deepsource") + if err != nil { + t.Fatalf("extractFromTarGz: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("extracted content mismatch: got %q, want %q", got, content) + } +} + +func TestExtractFromTarGz_NotFound(t *testing.T) { + content := []byte("data") + archive := createTarGz(t, "other-binary", content) + + _, err := extractFromTarGz(archive, "deepsource") + if err == nil { + t.Fatal("expected error for missing binary") + } +} + +func TestExtractFromZip(t *testing.T) { + content := []byte("windows binary data") + archive := createZip(t, "deepsource.exe", content) + + got, err := extractFromZip(archive, "deepsource.exe") + if err != nil { + t.Fatalf("extractFromZip: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("extracted content mismatch: got %q, want %q", got, content) + } +} + +func TestExtractFromZip_NotFound(t *testing.T) { + content := []byte("data") + archive := createZip(t, "other.exe", content) + + _, err := extractFromZip(archive, "deepsource.exe") + if err == nil { + t.Fatal("expected error for missing binary") + } +} + +func TestReplaceBinary(t *testing.T) { + dir := t.TempDir() + fakeBin := filepath.Join(dir, "deepsource") + if err := os.WriteFile(fakeBin, []byte("old"), 0o755); err != nil { + t.Fatal(err) + } + + // Point os.Executable to our fake binary by using a symlink + // Since we can't override os.Executable, test replaceBinary directly + // by calling the internal logic with a known path. + newContent := []byte("new binary content") + + // Write new binary to temp, rename + tmp, err := os.CreateTemp(dir, "deepsource.new.*") + if err != nil { + t.Fatal(err) + } + if _, err := tmp.Write(newContent); err != nil { + t.Fatal(err) + } + if err := tmp.Chmod(0o755); err != nil { + t.Fatal(err) + } + tmp.Close() + + bakPath := fakeBin + ".bak" + if err := os.Rename(fakeBin, bakPath); err != nil { + t.Fatal(err) + } + if err := os.Rename(tmp.Name(), fakeBin); err != nil { + t.Fatal(err) + } + os.Remove(bakPath) + + got, err := os.ReadFile(fakeBin) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, newContent) { + t.Errorf("binary content mismatch: got %q, want %q", got, newContent) + } +} + +func TestShouldAutoUpdate_DevBuild(t *testing.T) { + buildinfo.SetBuildInfo("2.0.3", "", "dev") + + // Clear CI vars so they don't interfere + ciVars := []string{"CI", "GITHUB_ACTIONS", "GITLAB_CI", "CIRCLECI", "TRAVIS", "JENKINS_URL", "BUILDKITE", "TF_BUILD"} + for _, v := range ciVars { + t.Setenv(v, "") + } + + if !ShouldAutoUpdate() { + t.Error("expected true for dev build with real version") + } +} + +func TestShouldAutoUpdate_DevelopmentVersion(t *testing.T) { + buildinfo.SetBuildInfo("development", "", "") + if ShouldAutoUpdate() { + t.Error("expected false for development version") + } +} + +func TestShouldAutoUpdate_CI(t *testing.T) { + buildinfo.SetBuildInfo("2.0.3", "", "prod") + t.Setenv("CI", "true") + if ShouldAutoUpdate() { + t.Error("expected false in CI") + } +} + +func TestShouldAutoUpdate_Prod(t *testing.T) { + buildinfo.SetBuildInfo("2.0.3", "", "prod") + + // Clear CI vars + ciVars := []string{"CI", "GITHUB_ACTIONS", "GITLAB_CI", "CIRCLECI", "TRAVIS", "JENKINS_URL", "BUILDKITE", "TF_BUILD"} + for _, v := range ciVars { + t.Setenv(v, "") + } + + if !ShouldAutoUpdate() { + t.Error("expected true for prod build outside CI") + } +} + +func TestUpdateState_WriteReadClear(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + state := &UpdateState{ + Version: "2.0.40", + ArchiveURL: "https://cli.deepsource.com/deepsource_2.0.40_darwin_arm64.tar.gz", + SHA256: "d1717cf33a200d143995c63be28661ed6d21c1380874f3057d3f25f6d9e2b99a", + CheckedAt: time.Date(2026, 3, 1, 20, 0, 0, 0, time.UTC), + } + + if err := writeUpdateState(state); err != nil { + t.Fatalf("writeUpdateState: %v", err) + } + + got, err := ReadUpdateState() + if err != nil { + t.Fatalf("ReadUpdateState: %v", err) + } + if got == nil { + t.Fatal("expected non-nil state") + } + if got.Version != state.Version { + t.Errorf("version: got %q, want %q", got.Version, state.Version) + } + if got.ArchiveURL != state.ArchiveURL { + t.Errorf("archive_url: got %q, want %q", got.ArchiveURL, state.ArchiveURL) + } + if got.SHA256 != state.SHA256 { + t.Errorf("sha256: got %q, want %q", got.SHA256, state.SHA256) + } + if !got.CheckedAt.Equal(state.CheckedAt) { + t.Errorf("checked_at: got %v, want %v", got.CheckedAt, state.CheckedAt) + } + + clearUpdateState() + + got, err = ReadUpdateState() + if err != nil { + t.Fatalf("ReadUpdateState after clear: %v", err) + } + if got != nil { + t.Error("expected nil state after clear") + } +} + +func TestReadUpdateState_NoFile(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + got, err := ReadUpdateState() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != nil { + t.Error("expected nil state when no file exists") + } +} + +func TestCheckForUpdate_NewerVersion(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + buildinfo.SetBuildInfo("2.0.30", "", "prod") + + key := runtime.GOOS + "_" + runtime.GOARCH + manifest := Manifest{ + Version: "2.0.40", + Platforms: map[string]PlatformInfo{ + key: { + Archive: "deepsource_2.0.40_" + key + ".tar.gz", + SHA256: "abc123", + }, + }, + } + // Simulate what CheckForUpdate does: fetch manifest, compare, write state + newer, _ := IsNewer("2.0.30", manifest.Version) + if !newer { + t.Fatal("expected newer=true") + } + + platform := manifest.Platforms[key] + state := &UpdateState{ + Version: manifest.Version, + ArchiveURL: buildinfo.BaseURL + "/" + platform.Archive, + SHA256: platform.SHA256, + CheckedAt: time.Now().UTC(), + } + if err := writeUpdateState(state); err != nil { + t.Fatalf("writeUpdateState: %v", err) + } + + got, err := ReadUpdateState() + if err != nil { + t.Fatalf("ReadUpdateState: %v", err) + } + if got.Version != "2.0.40" { + t.Errorf("expected version 2.0.40, got %s", got.Version) + } + expectedURL := fmt.Sprintf("%s/deepsource_2.0.40_%s.tar.gz", buildinfo.BaseURL, key) + if got.ArchiveURL != expectedURL { + t.Errorf("expected archive URL %s, got %s", expectedURL, got.ArchiveURL) + } +} + +func TestCheckForUpdate_AlreadyUpToDate(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + buildinfo.SetBuildInfo("2.0.40", "", "prod") + + // Same version — no state file should be written + newer, _ := IsNewer("2.0.40", "2.0.40") + if newer { + t.Fatal("expected newer=false for same version") + } + + got, err := ReadUpdateState() + if err != nil { + t.Fatalf("ReadUpdateState: %v", err) + } + if got != nil { + t.Error("expected no state file for up-to-date version") + } +} + +func TestApplyUpdate_WithStateFile(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + buildinfo.SetBuildInfo("2.0.30", "", "prod") + + binaryContent := []byte("new deepsource binary") + archive := createTarGz(t, buildinfo.AppName, binaryContent) + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(archive) + })) + defer srv.Close() + + state := &UpdateState{ + Version: "2.0.40", + ArchiveURL: srv.URL + "/deepsource_2.0.40.tar.gz", + SHA256: checksumHex, + CheckedAt: time.Now().UTC(), + } + if err := writeUpdateState(state); err != nil { + t.Fatalf("writeUpdateState: %v", err) + } + + client := srv.Client() + + // ApplyUpdate reads the state file internally, so we test it reads correctly. + // However, replaceBinary will use os.Executable() which we can't easily mock. + // So we test the pieces: state file reading, download, checksum verification. + + readState, err := ReadUpdateState() + if err != nil { + t.Fatalf("ReadUpdateState: %v", err) + } + if readState.Version != "2.0.40" { + t.Fatalf("expected version 2.0.40, got %s", readState.Version) + } + + data, err := downloadFile(client, readState.ArchiveURL) + if err != nil { + t.Fatalf("downloadFile: %v", err) + } + + if err := verifyChecksum(data, readState.SHA256); err != nil { + t.Fatalf("verifyChecksum: %v", err) + } + + extracted, err := extractFromTarGz(data, buildinfo.AppName) + if err != nil { + t.Fatalf("extractFromTarGz: %v", err) + } + if !bytes.Equal(extracted, binaryContent) { + t.Errorf("extracted binary mismatch: got %q, want %q", extracted, binaryContent) + } + + // Verify clearUpdateState removes the file + clearUpdateState() + afterClear, _ := ReadUpdateState() + if afterClear != nil { + t.Error("state file should be removed after clear") + } +} + +func TestApplyUpdate_NoStateFile(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + client := &http.Client{Timeout: 1 * time.Second} + ver, err := ApplyUpdate(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ver != "" { + t.Errorf("expected empty version, got %q", ver) + } +} + +// helpers + +func createTarGz(t *testing.T, name string, content []byte) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + hdr := &tar.Header{ + Name: name, + Size: int64(len(content)), + Mode: 0o755, + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if _, err := tw.Write(content); err != nil { + t.Fatal(err) + } + tw.Close() + gw.Close() + return buf.Bytes() +} + +func createZip(t *testing.T, name string, content []byte) []byte { + t.Helper() + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + fw, err := zw.Create(name) + if err != nil { + t.Fatal(err) + } + if _, err := fw.Write(content); err != nil { + t.Fatal(err) + } + zw.Close() + return buf.Bytes() +} diff --git a/internal/update/version.go b/internal/update/version.go new file mode 100644 index 00000000..697421ad --- /dev/null +++ b/internal/update/version.go @@ -0,0 +1,51 @@ +package update + +import ( + "fmt" + "strconv" + "strings" +) + +// IsNewer reports whether remote is a newer semver than current. +// Both values may optionally have a "v" prefix (e.g. "v2.0.3"). +func IsNewer(current, remote string) (bool, error) { + curMaj, curMin, curPatch, err := parseSemver(current) + if err != nil { + return false, fmt.Errorf("parsing current version %q: %w", current, err) + } + remMaj, remMin, remPatch, err := parseSemver(remote) + if err != nil { + return false, fmt.Errorf("parsing remote version %q: %w", remote, err) + } + + if remMaj != curMaj { + return remMaj > curMaj, nil + } + if remMin != curMin { + return remMin > curMin, nil + } + return remPatch > curPatch, nil +} + +func parseSemver(v string) (major, minor, patch int, err error) { + v = strings.TrimPrefix(v, "v") + parts := strings.SplitN(v, ".", 3) + if len(parts) != 3 { + return 0, 0, 0, fmt.Errorf("expected X.Y.Z, got %q", v) + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, 0, err + } + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, 0, err + } + // Strip pre-release or build metadata suffix (e.g. "44-e888cf0f" → "44") + patchStr, _, _ := strings.Cut(parts[2], "-") + patch, err = strconv.Atoi(patchStr) + if err != nil { + return 0, 0, 0, err + } + return major, minor, patch, nil +} diff --git a/internal/update/version_test.go b/internal/update/version_test.go new file mode 100644 index 00000000..63881431 --- /dev/null +++ b/internal/update/version_test.go @@ -0,0 +1,39 @@ +package update + +import "testing" + +func TestIsNewer(t *testing.T) { + tests := []struct { + name string + current string + remote string + want bool + wantErr bool + }{ + {"same version", "2.0.3", "2.0.3", false, false}, + {"patch bump", "2.0.3", "2.0.4", true, false}, + {"minor bump", "2.0.3", "2.1.0", true, false}, + {"major bump", "2.0.3", "3.0.0", true, false}, + {"older patch", "2.0.4", "2.0.3", false, false}, + {"older minor", "2.1.0", "2.0.9", false, false}, + {"older major", "3.0.0", "2.9.9", false, false}, + {"v prefix current", "v2.0.3", "2.0.4", true, false}, + {"v prefix remote", "2.0.3", "v2.0.4", true, false}, + {"v prefix both", "v2.0.3", "v2.0.4", true, false}, + {"invalid current", "abc", "2.0.4", false, true}, + {"invalid remote", "2.0.3", "xyz", false, true}, + {"two parts", "2.0", "2.0.1", false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := IsNewer(tt.current, tt.remote) + if (err != nil) != tt.wantErr { + t.Errorf("IsNewer() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("IsNewer(%q, %q) = %v, want %v", tt.current, tt.remote, got, tt.want) + } + }) + } +}