diff --git a/CHANGELOG.md b/CHANGELOG.md index 38e1437..a017a9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - VS Code extension snippets for common Testo blocks. +### Fixed + +- Fixed a bug when long cache keys could trigger an error. + ## [1.3.0] - 2026-05-24 ### Added diff --git a/testocache/cache.go b/testocache/cache.go index d7bedb8..0c5ddf4 100644 --- a/testocache/cache.go +++ b/testocache/cache.go @@ -13,16 +13,18 @@ package testocache import ( + "bytes" "cmp" "errors" "flag" - "io/fs" + "hash/fnv" + "io" "os" + "path" "path/filepath" + "slices" "strconv" - "strings" "sync" - "unicode" ) var ( @@ -38,8 +40,17 @@ var ( ) ) -// ErrDisabled indicates that caching is disabled. -var ErrDisabled = errors.New("cache is disabled") +var ( + // ErrDisabled indicates that caching is disabled. + ErrDisabled = errors.New("testocache: cache is disabled") + + // ErrInvalidKey indicates that passed key is invalid. + // Currently, key is invalid if it contains a NUL-byte. + ErrInvalidKey = errors.New("testocache: invalid key") + + // ErrNotFound indicates that value was not found the passed key. + ErrNotFound = errors.New("testocache: not found") +) const ( permFile os.FileMode = 0o600 @@ -56,10 +67,36 @@ func Disabled() bool { var kvMu sync.RWMutex // Keys returns all glob-matched keys by the given pattern. -// E.g. "myplugin-prefix-*" +// +// The pattern syntax is: +// +// pattern: +// { term } +// term: +// '*' matches any sequence of non-/ characters +// '?' matches any single non-/ character +// '[' [ '^' ] { character-range } ']' +// character class (must be non-empty) +// c matches character c (c != '*', '?', '\\', '[') +// '\\' c matches character c +// +// character-range: +// c matches character c (c != '\\', '-', ']') +// '\\' c matches character c +// lo '-' hi matches character c for lo <= c <= hi +// +// Keys requires pattern to match all of name, not just a substring. // // If cache is disabled (see [Disabled]), this function returns [ErrDisabled]. func Keys(pattern string) (keys []string, err error) { + if err := validate(pattern); err != nil { + return nil, err + } + + if _, err := path.Match(pattern, ""); err != nil { + return nil, err + } + dir, err := cacheDir() if err != nil { return nil, err @@ -68,13 +105,71 @@ func Keys(pattern string) (keys []string, err error) { kvMu.RLock() defer kvMu.RUnlock() - return fs.Glob(os.DirFS(dir), pattern) + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + keys = make([]string, 0, len(keys)) + + for _, e := range entries { + key, err := extractKey(filepath.Join(dir, e.Name())) + if err != nil { + return nil, err + } + + if ok, _ := path.Match(pattern, key); ok { + keys = append(keys, key) + } + } + + return keys, nil +} + +func extractKey(p string) (string, error) { + f, err := os.Open(p) + if err != nil { + return "", err + } + defer f.Close() + + var collected []byte + + // heuristic + chunk := make([]byte, 32) + + for { + n, err := io.ReadAtLeast(f, chunk, 1) + if err != nil { + if errors.Is(err, io.EOF) { + return "", nil + } + + return "", err + } + + before, _, ok := bytes.Cut(chunk[:n], []byte{0}) + if ok { + if len(collected) == 0 { + return string(before), nil + } + + return string(append(collected, before...)), nil + } + + collected = append(collected, before...) + } } // Get cached object by the given key. +// Key must not contain a NUL-byte. // // If cache is disabled (see [Disabled]), this function returns [ErrDisabled]. func Get(key string) ([]byte, error) { + if err := validate(key); err != nil { + return nil, err + } + dir, err := cacheDir() if err != nil { return nil, err @@ -83,15 +178,40 @@ func Get(key string) ([]byte, error) { kvMu.RLock() defer kvMu.RUnlock() - path := filepath.Join(dir, sanitizeFilename(key)) + h, err := hash(key) + if err != nil { + return nil, err + } + + p := filepath.Join(dir, h) + + _, err = os.Stat(p) + if err != nil { + return nil, ErrNotFound + } + + value, err := os.ReadFile(p) + if err != nil { + return nil, err + } + + _, after, ok := bytes.Cut(value, []byte{0}) + if !ok { + return value, nil + } - return os.ReadFile(path) + return after, nil } // Set saves value to cache with the given key. +// Key must not contain a NUL-byte. // // If cache is disabled (see [Disabled]), this function returns [ErrDisabled]. func Set(key string, value []byte) error { + if err := validate(key); err != nil { + return err + } + dir, err := cacheDir() if err != nil { return err @@ -100,15 +220,32 @@ func Set(key string, value []byte) error { kvMu.Lock() defer kvMu.Unlock() - path := filepath.Join(dir, sanitizeFilename(key)) + h, err := hash(key) + if err != nil { + return err + } + + p := filepath.Join(dir, h) + + buf := bytes.NewBufferString(key) + + buf.Grow(1 + len(value)) + + buf.WriteByte(0) + buf.Write(value) - return os.WriteFile(path, value, permFile) + return os.WriteFile(p, buf.Bytes(), permFile) } // Remove object from cache by the given key. +// Key must not contain a NUL-byte. // // If cache is disabled (see [Disabled]), this function returns [ErrDisabled]. func Remove(key string) error { + if err := validate(key); err != nil { + return err + } + dir, err := cacheDir() if err != nil { return err @@ -117,9 +254,14 @@ func Remove(key string) error { kvMu.Lock() defer kvMu.Unlock() - path := filepath.Join(dir, sanitizeFilename(key)) + h, err := hash(key) + if err != nil { + return err + } + + p := filepath.Join(dir, h) - return os.Remove(path) + return os.Remove(p) } func cacheDir() (string, error) { @@ -146,25 +288,21 @@ func parseBool(s string) bool { return b } -func sanitizeFilename(name string) string { - var sb strings.Builder - - sb.Grow(len(name)) +func validate(key string) error { + if slices.Contains([]byte(key), 0) { + return ErrInvalidKey + } - const ( - invalid = `\/<>:\"|?*.` - replacement = '-' - ) + return nil +} - for _, r := range name { - switch { - case r == 0, unicode.IsControl(r), strings.ContainsRune(invalid, r): - sb.WriteRune(replacement) +func hash(key string) (string, error) { + h := fnv.New64a() - default: - sb.WriteRune(r) - } + _, err := h.Write([]byte(key)) + if err != nil { + return "", err } - return sb.String() + return strconv.FormatUint(h.Sum64(), 36), nil } diff --git a/testocache/cache_test.go b/testocache/cache_test.go index 430cf1f..5a0c186 100644 --- a/testocache/cache_test.go +++ b/testocache/cache_test.go @@ -1,11 +1,45 @@ package testocache import ( + "errors" "slices" "testing" ) -func Test(t *testing.T) { +func TestInvalidKey(t *testing.T) { + t.Parallel() + + const invalid = "foo\x00bar" + + t.Run("set", func(t *testing.T) { + t.Parallel() + + err := Set(invalid, []byte("...")) + if !errors.Is(err, ErrInvalidKey) { + t.Fatalf("err is not ErrInvalidKey: %v", err) + } + }) + + t.Run("get", func(t *testing.T) { + t.Parallel() + + _, err := Get(invalid) + if !errors.Is(err, ErrInvalidKey) { + t.Fatalf("err is not ErrInvalidKey: %v", err) + } + }) + + t.Run("remove", func(t *testing.T) { + t.Parallel() + + err := Remove(invalid) + if !errors.Is(err, ErrInvalidKey) { + t.Fatalf("err is not ErrInvalidKey: %v", err) + } + }) +} + +func TestFlow(t *testing.T) { t.Parallel() for _, tt := range []struct { @@ -13,7 +47,7 @@ func Test(t *testing.T) { Value string }{ {Key: "my-key", Value: "lorem ipsum\ndolor sit \t\tamet"}, - {Key: "key/with/slash", Value: "other value"}, + {Key: "key~with~tilde", Value: "other value"}, } { t.Run("with key: "+tt.Key, func(t *testing.T) { err := Set(tt.Key, []byte(tt.Value)) @@ -37,7 +71,9 @@ func Test(t *testing.T) { t.Fatalf("failed to get keys: %v", err) } - wantKeys := []string{"key-with-slash", "my-key"} + slices.Sort(keys) + + wantKeys := []string{"key~with~tilde", "my-key"} if !slices.Equal(keys, wantKeys) { t.Fatalf("keys: want %v, got %v", wantKeys, keys) } @@ -48,4 +84,9 @@ func Test(t *testing.T) { t.Errorf("remove key %q: %v", k, err) } } + + _, err = Get("unknown-key") + if !errors.Is(err, ErrNotFound) { + t.Fatal("expected not found error") + } }