diff --git a/pyproject.toml b/pyproject.toml index 3bf11124..738478ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ sdc-agents = [ "sdc-agents>=4.3.3; python_version >= '3.11'", ] spraay = ["web3>=6.0.0"] +valkey = [ + "valkey-glide>=2.4.0", +] [tool.pyink] diff --git a/src/google/adk_community/memory/README.md b/src/google/adk_community/memory/README.md new file mode 100644 index 00000000..25f4bd45 --- /dev/null +++ b/src/google/adk_community/memory/README.md @@ -0,0 +1,133 @@ +# Memory Services + +Community-contributed memory service implementations for the +[Google ADK](https://google.github.io/adk-docs/) framework. + +## Available Services + +### ValkeyMemoryService + +A memory service backed by [Valkey](https://valkey.io/) using the +[Valkey Search module](https://valkey.io/topics/search/) for vector +similarity search. Uses the [valkey-glide](https://github.com/valkey-io/valkey-glide) +client library. This provides functionality analogous to +`VertexAiRagMemoryService` for users with Valkey infrastructure. + +**Features:** +- Vector similarity search (HNSW) powered by the Valkey Search module +- Configurable embedding function (bring your own: OpenAI, Gemini, sentence-transformers, etc.) 
+- KNN search with pre-filtering by `app_name` and `user_id` TAG fields +- Configurable `vector_distance_threshold` for filtering low-quality matches +- Configurable distance metric (COSINE, L2, IP) +- Optional TTL for automatic memory expiration +- Batch embedding generation for efficient ingestion +- Supports `add_session_to_memory` and `add_events_to_memory` + +**Requirements:** +- Valkey server with the Search module loaded (e.g., + [valkey-bundle](https://hub.docker.com/r/valkey/valkey-bundle) image) +- `valkey-glide >= 2.4.0` +- An embedding function (async callable) + +**Installation:** + +```bash +pip install google-adk-community[valkey] +``` + +**Usage:** + +```python +from glide import GlideClient, GlideClientConfiguration, NodeAddress +from google.adk_community.memory import ValkeyMemoryService, ValkeyMemoryServiceConfig + +# 1. Create a valkey-glide client +# IMPORTANT: Set client_name for observability — it appears in CLIENT LIST, +# monitoring dashboards, and CloudWatch metrics. +config = GlideClientConfiguration( + addresses=[NodeAddress(host="localhost", port=6379)], + client_name="my_adk_memory_client", +) +client = await GlideClient.create(config) + +# 2. Define your embedding function (bring your own model) +# Example with Google Gemini (async calls live under `client.aio`): +from google import genai +genai_client = genai.Client() + +async def embed_texts(texts: list[str]) -> list[list[float]]: + response = await genai_client.aio.models.embed_content( + model="text-embedding-004", + contents=texts, + ) + return [e.values for e in response.embeddings] + +# 3. 
Create the memory service +memory_config = ValkeyMemoryServiceConfig( + similarity_top_k=10, # Max results per search (KNN) + vector_distance_threshold=0.6, # Filter distant results (optional) + embedding_dimensions=768, # Must match your embedding model + key_prefix="adk:memory", # Valkey key prefix + index_name="adk_memory_idx", # Search index name + distance_metric="COSINE", # COSINE, L2, or IP + ttl_seconds=None, # Optional TTL (None = no expiry) +) +memory_service = ValkeyMemoryService( + client=client, + embedding_function=embed_texts, + config=memory_config, +) + +# The index is created automatically on first use, or explicitly: +await memory_service.create_index() + +# 4. Use with ADK Runner +from google.adk.runners import Runner + +runner = Runner( + agent=my_agent, + memory_service=memory_service, + ... +) +``` + +**How it works:** + +1. When `add_session_to_memory` is called, text is extracted from session + events, embeddings are generated in batch using your embedding function, + and each event is stored as a Valkey Hash with the embedding vector. + +2. When `search_memory` is called, an embedding is generated for the query, + then `FT.SEARCH` performs a KNN search with pre-filtering by `app_name` + and `user_id` TAG fields. Results are returned ranked by vector similarity. + +**Running Valkey with Search module:** + +```bash +# Using podman +podman run -d --name valkey -p 6379:6379 valkey/valkey-bundle:9.1 + +# Using docker +docker run -d --name valkey -p 6379:6379 valkey/valkey-bundle:9.1 +``` + +**Note on Redis Session Service:** The existing `RedisSessionService` in this +repo is wire-protocol compatible with Valkey. You can point it at a Valkey +instance directly for session storage without needing a separate Valkey +session service. + +--- + +### OpenMemoryService + +A memory service backed by [OpenMemory](https://openmemory.cavira.app/). +Uses HTTP API calls for memory storage and retrieval with LLM-powered +memory extraction. 
+ +**Installation:** + +```bash +pip install google-adk-community +``` + +See the `OpenMemoryService` class documentation for usage details. diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c0..f029ec77 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -16,9 +16,13 @@ from .open_memory_service import OpenMemoryService from .open_memory_service import OpenMemoryServiceConfig +from .valkey_memory_service import ValkeyMemoryService +from .valkey_memory_service import ValkeyMemoryServiceConfig __all__ = [ "OpenMemoryService", "OpenMemoryServiceConfig", + "ValkeyMemoryService", + "ValkeyMemoryServiceConfig", ] diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py new file mode 100644 index 00000000..6cc155a7 --- /dev/null +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -0,0 +1,539 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Valkey-backed memory service for ADK using valkey-glide client. + +Uses the Valkey Search module with vector similarity search (HNSW) +for semantic memory retrieval, analogous to VertexAiRagMemoryService. 
class ValkeyMemoryServiceConfig(BaseModel):
  """Settings that control how ValkeyMemoryService stores and searches.

  Attributes:
    similarity_top_k: Upper bound on the number of memories returned by
      one search (the K of the KNN query). Must be between 1 and 1000.
    vector_distance_threshold: Optional, non-negative cut-off on vector
      distance. Hits farther away than this value are dropped. None
      disables the cut-off.
    embedding_dimensions: Length of every embedding vector (at least 1).
    key_prefix: String prepended to every Valkey key the service writes.
    index_name: Name given to the Valkey Search index.
    distance_metric: Vector distance metric: 'COSINE', 'L2', or 'IP'
      (inner product). Matched case-insensitively and stored upper-cased.
    ttl_seconds: Optional lifetime of a memory entry in seconds (at
      least 1). None means entries never expire.
  """

  similarity_top_k: int = Field(default=10, ge=1, le=1000)
  vector_distance_threshold: Optional[float] = Field(default=None, ge=0.0)
  embedding_dimensions: int = Field(default=768, ge=1)
  key_prefix: str = Field(default="adk:memory")
  index_name: str = Field(default="adk_memory_idx")
  distance_metric: str = Field(default="COSINE")
  ttl_seconds: Optional[int] = Field(default=None, ge=1)

  @field_validator("distance_metric")
  @classmethod
  def _validate_distance_metric(cls, v):
    """Upper-case distance_metric and reject anything but the known metrics."""
    allowed = {"COSINE", "L2", "IP"}
    normalized = v.upper()
    if normalized in allowed:
      return normalized
    raise ValueError(f"distance_metric must be one of {allowed}, got '{v}'")
class ValkeyMemoryService(BaseMemoryService):
  """Memory service using Valkey Search module with vector similarity.

  Uses valkey-glide client for communication with Valkey server and the
  Valkey Search module for vector-based semantic search over stored
  memories. This provides functionality analogous to
  VertexAiRagMemoryService but backed by Valkey infrastructure.

  Memories are stored as Valkey Hash keys with fields: content, author,
  timestamp, session_id, event_id, app_name, user_id, created_at, and
  an embedding vector field. A vector search index (HNSW) is created
  for approximate nearest neighbor retrieval, with TAG fields for
  app_name and user_id to enable scoped queries.

  Example usage:

    from glide import GlideClientConfiguration, NodeAddress, GlideClient

    # IMPORTANT: Set client_name for observability in CLIENT LIST,
    # monitoring dashboards, and CloudWatch metrics.
    config = GlideClientConfiguration(
        addresses=[NodeAddress(host="localhost", port=6379)],
        client_name="adk_memory_client",
    )
    client = await GlideClient.create(config)

    async def my_embed_fn(texts: list[str]) -> list[list[float]]:
        # Your embedding logic here (OpenAI, Gemini, etc.)
        ...

    service = ValkeyMemoryService(
        client=client,
        embedding_function=my_embed_fn,
    )
    await service.create_index()

  """

  def __init__(
      self,
      client: Union["GlideClient", "GlideClusterClient"],
      embedding_function: EmbeddingFunction,
      config: Optional[ValkeyMemoryServiceConfig] = None,
  ):
    """Initializes the Valkey memory service.

    Args:
      client: A connected valkey-glide GlideClient or
        GlideClusterClient instance. The caller is responsible
        for creating and managing the client lifecycle.
      embedding_function: An async callable that takes a list of
        text strings and returns a list of embedding vectors
        (list of floats). Users provide their own embedding
        model (e.g., OpenAI, Google Gemini, sentence-transformers).
      config: Optional ValkeyMemoryServiceConfig instance.
        If None, uses defaults.

    Raises:
      ValueError: If client or embedding_function is None.
    """
    if client is None:
      raise ValueError(
          "client is required. Provide a connected valkey-glide "
          "GlideClient or GlideClusterClient instance."
      )
    if embedding_function is None:
      raise ValueError(
          "embedding_function is required. Provide an async callable "
          "that takes list[str] and returns list[list[float]]."
      )
    self._client = client
    self._embedding_function = embedding_function
    self._config = config or ValkeyMemoryServiceConfig()
    # Set once create_index() succeeds or finds the index already present.
    self._index_created = False
    # Serializes lazy index creation triggered from _ensure_index().
    self._index_lock = asyncio.Lock()

  async def create_index(self):
    """Create the Valkey Search index if it does not already exist.

    Creates a vector search index (HNSW) with:
    - embedding: VECTOR field (HNSW, FLOAT32) for similarity search
    - content: TEXT field for optional full-text filtering
    - app_name: TAG field for filtering by application
    - user_id: TAG field for filtering by user
    - author: TAG field for filtering by author

    This method is idempotent — if the index already exists, it
    will log a debug message and return without error.

    Raises:
      Exception: Any error from index creation other than an
        "already exists" reply is re-raised unchanged.
    """
    from glide import DataType
    from glide import DistanceMetricType
    from glide import ft
    from glide import FtCreateOptions
    from glide import TagField
    from glide import TextField
    from glide import VectorAlgorithm
    from glide import VectorField
    from glide import VectorFieldAttributesHnsw
    from glide import VectorType

    distance_map = {
        "COSINE": DistanceMetricType.COSINE,
        "L2": DistanceMetricType.L2,
        "IP": DistanceMetricType.IP,
    }
    distance_metric = distance_map.get(
        self._config.distance_metric.upper(), DistanceMetricType.COSINE
    )

    schema = [
        VectorField(
            "embedding",
            algorithm=VectorAlgorithm.HNSW,
            attributes=VectorFieldAttributesHnsw(
                dimensions=self._config.embedding_dimensions,
                distance_metric=distance_metric,
                type=VectorType.FLOAT32,
            ),
        ),
        TextField("content"),
        TagField("app_name"),
        TagField("user_id"),
        TagField("author"),
    ]

    # Only hashes written under this service's key prefix are indexed.
    options = FtCreateOptions(
        data_type=DataType.HASH,
        prefixes=[f"{self._config.key_prefix}:"],
    )

    try:
      await ft.create(
          self._client,
          self._config.index_name,
          schema,
          options,
      )
      self._index_created = True
      logger.info("Created search index: %s", self._config.index_name)
    except Exception as e:
      # Only an "already exists" reply means the index is usable. Matching
      # the bare word "exists" would also swallow unrelated failures and
      # wrongly mark the index as created.
      if "already exists" in str(e).lower():
        self._index_created = True
        logger.debug(
            "Search index already exists: %s",
            self._config.index_name,
        )
      else:
        raise

  def _memory_hash_key(self) -> str:
    """Generate a unique Valkey hash key for a memory entry.

    The full UUID4 hex is used: a truncated id raises the collision odds,
    and a colliding key would make HSET silently overwrite another
    memory's fields.
    """
    unique_id = uuid.uuid4().hex
    return f"{self._config.key_prefix}:{unique_id}"

  async def _ensure_index(self):
    """Create the search index on first use.

    The created-flag is checked before taking the lock and again after
    acquiring it, so concurrent callers trigger create_index() at most
    once while the flag is unset.
    """
    if self._index_created:
      return
    async with self._index_lock:
      if not self._index_created:
        await self.create_index()

  @staticmethod
  def _vector_to_bytes(vector: list[float]) -> bytes:
    """Pack a list of floats into little-endian 32-bit floats.

    The resulting blob matches the FLOAT32 vector type declared for the
    index's embedding field.
    """
    return struct.pack(f"<{len(vector)}f", *vector)

  @override
  async def add_session_to_memory(self, session: Session):
    """Add a session's events to Valkey memory storage.

    Extracts text from session events, generates embeddings using the
    configured embedding function, and stores each event as a Valkey
    Hash with the embedding vector for later similarity search.

    Args:
      session: The session whose events, app_name, user_id and id are
        ingested.
    """
    await self._ensure_index()
    await self._ingest_events(
        events=session.events,
        app_name=session.app_name,
        user_id=session.user_id,
        session_id=session.id,
    )

  @override
  async def add_events_to_memory(
      self,
      *,
      app_name: str,
      user_id: str,
      events: Sequence,
      session_id: str | None = None,
      custom_metadata=None,
  ) -> None:
    """Adds an incremental list of events to memory.

    Generates embeddings and stores each event with text content.
    This is useful for persisting only a subset of events (e.g.,
    the latest turn) without re-ingesting the full session.

    Args:
      app_name: The application name for memory scope.
      user_id: The user ID for memory scope.
      events: The events to add to memory.
      session_id: Optional session ID for partitioning. None is stored
        as an empty string.
      custom_metadata: Optional metadata (unused currently).
    """
    await self._ensure_index()
    await self._ingest_events(
        events=events,
        app_name=app_name,
        user_id=user_id,
        session_id=session_id or "",
    )

  async def _ingest_events(
      self,
      events: Sequence,
      app_name: str,
      user_id: str,
      session_id: str,
  ) -> None:
    """Shared ingestion logic for add_session_to_memory and add_events_to_memory.

    Extracts text from events, generates embeddings in batch, and
    stores each event as a Valkey Hash using pipelined Batch commands
    for efficiency.

    Embedding failures, and a mismatch between the number of texts and
    the number of returned vectors, are logged and abort ingestion
    without raising. A failure while executing the batch is raised.

    Args:
      events: The events to ingest.
      app_name: The application name for memory scope.
      user_id: The user ID for memory scope.
      session_id: The session ID for partitioning.

    Raises:
      RuntimeError: If executing the batch against Valkey fails.
    """
    from glide import Batch

    # Collect texts and their corresponding events
    texts = []
    valid_events = []
    for event in events:
      content_text = extract_text_from_event(event)
      if content_text:
        texts.append(content_text)
        valid_events.append(event)

    if not texts:
      logger.debug("No text events to ingest")
      return

    # Generate embeddings for all texts in one batch
    try:
      embeddings = await self._embedding_function(texts)
    except Exception as e:
      logger.error("Failed to generate embeddings: %s", e)
      return

    if len(embeddings) != len(texts):
      logger.error(
          "Embedding function returned %d vectors for %d texts",
          len(embeddings),
          len(texts),
      )
      return

    # Build a Batch to pipeline all hset + expire calls
    batch = Batch(is_atomic=False)
    hash_keys = []
    for event, content_text, embedding in zip(valid_events, texts, embeddings):
      hash_key = self._memory_hash_key()
      hash_keys.append(hash_key)
      field_values = {
          "content": content_text,
          "author": event.author or "",
          "timestamp": str(event.timestamp) if event.timestamp else "0",
          "session_id": session_id,
          "event_id": event.id or "",
          "app_name": app_name,
          "user_id": user_id,
          "created_at": str(time.time()),
          "embedding": self._vector_to_bytes(embedding),
      }

      batch.hset(hash_key, field_values)
      if self._config.ttl_seconds is not None:
        batch.expire(hash_key, self._config.ttl_seconds)

    try:
      await self._client.exec(batch, raise_on_error=True)
      logger.info("Added %d memories via batch pipeline", len(hash_keys))
    except Exception as e:
      logger.error("Failed to execute batch pipeline: %s", e)
      raise RuntimeError(f"Memory ingestion failed: {e}") from e

  # Characters that must be escaped in Valkey Search TAG field values.
  # Includes '?' which is a single-character wildcard glob in TAG queries.
  _TAG_SPECIAL_CHARS = set(r',.<>{}[]"' + r"':;!@#$%^&*()-+=~|/\\ ?")

  @staticmethod
  def _escape_tag_value(value: str) -> str:
    """Escape special characters for Valkey Search TAG field queries.

    Per the Valkey Search query syntax, TAG values must have
    metacharacters escaped with a backslash.

    Args:
      value: The raw TAG value.

    Returns:
      The value with every character found in _TAG_SPECIAL_CHARS
      prefixed by a backslash.
    """
    escaped = []
    for ch in value:
      if ch in ValkeyMemoryService._TAG_SPECIAL_CHARS:
        escaped.append(f"\\{ch}")
      else:
        escaped.append(ch)
    return "".join(escaped)

  def _build_knn_query(self, app_name: str, user_id: str, top_k: int) -> str:
    """Build a KNN search query with TAG pre-filters.

    Args:
      app_name: Application name filter.
      user_id: User ID filter.
      top_k: Number of nearest neighbors to retrieve.

    Returns:
      A Valkey Search KNN query string that references the query vector
      through the $query_vec parameter.
    """
    escaped_app = self._escape_tag_value(app_name)
    escaped_user = self._escape_tag_value(user_id)

    # KNN query with pre-filter: filter first, then KNN on results
    return (
        f"(@app_name:{{{escaped_app}}} "
        f"@user_id:{{{escaped_user}}})"
        f"=>[KNN {top_k} @embedding $query_vec]"
    )

  @override
  async def search_memory(
      self, *, app_name: str, user_id: str, query: str
  ) -> SearchMemoryResponse:
    """Search for memories using vector similarity (KNN).

    Generates an embedding for the query text, then performs a KNN
    search using FT.SEARCH with pre-filtering by app_name and
    user_id. Results are ranked by vector distance (lower = more
    similar for COSINE).

    Embedding and search failures are logged and yield an empty
    response rather than an exception.

    Args:
      app_name: The application name to scope the search.
      user_id: The user ID to scope the search.
      query: The search query string.

    Returns:
      SearchMemoryResponse containing matching MemoryEntry objects,
      ordered by similarity.
    """
    from glide import ft
    from glide import FtSearchOptions

    await self._ensure_index()

    # Generate embedding for the query
    try:
      query_embeddings = await self._embedding_function([query])
      query_embedding = query_embeddings[0]
    except Exception as e:
      logger.error("Failed to generate query embedding: %s", e)
      return SearchMemoryResponse(memories=[])

    query_vec_bytes = self._vector_to_bytes(query_embedding)

    try:
      search_query = self._build_knn_query(
          app_name, user_id, self._config.similarity_top_k
      )
      options = FtSearchOptions(
          params={"query_vec": query_vec_bytes},
      )

      result = await ft.search(
          self._client,
          self._config.index_name,
          search_query,
          options,
      )

      # The reply is read as [document count, {doc_id: fields}].
      if not result or len(result) < 2:
        return SearchMemoryResponse(memories=[])

      doc_count = result[0]
      if doc_count == 0:
        return SearchMemoryResponse(memories=[])

      memories = []
      doc_map = result[1]

      for doc_id, fields in doc_map.items():
        try:
          # Check distance threshold if configured
          if self._config.vector_distance_threshold is not None:
            score_raw = self._decode(fields.get(b"__embedding_score", b""))
            if score_raw:
              distance = float(score_raw)
              if distance > self._config.vector_distance_threshold:
                continue

          content_text = self._decode(fields.get(b"content", b""))
          if not content_text:
            continue

          author = self._decode(fields.get(b"author", b"")) or None
          timestamp_raw = self._decode(fields.get(b"timestamp", b"0"))
          if timestamp_raw and timestamp_raw != "0":
            # Numeric timestamps are truncated to whole numbers; anything
            # that does not parse as a number is passed through verbatim.
            try:
              timestamp = str(int(float(timestamp_raw)))
            except (ValueError, TypeError):
              timestamp = timestamp_raw
          else:
            timestamp = None

          content = types.Content(parts=[types.Part(text=content_text)])
          entry = MemoryEntry(
              content=content,
              author=author,
              timestamp=timestamp,
          )
          memories.append(entry)
        except Exception as e:
          # A single malformed document must not discard the other hits.
          logger.debug("Failed to parse search result: %s", e)
          continue

      logger.info("Found %d memories for query: '%s'", len(memories), query)
      return SearchMemoryResponse(memories=memories)

    except Exception as e:
      logger.error("Failed to search memories: %s", e)
      return SearchMemoryResponse(memories=[])

  @staticmethod
  def _decode(value) -> str:
    """Decode bytes to string if needed.

    Bytes are decoded as UTF-8, None becomes an empty string, and any
    other value is converted with str().
    """
    if isinstance(value, bytes):
      return value.decode("utf-8")
    return str(value) if value is not None else ""

  async def close(self):
    """Close the memory service.

    Note: This does NOT close the underlying Valkey client, as
    the client lifecycle is managed by the caller.
    """
    pass
# Simple deterministic embedding function for testing.
# Maps text to a 32-dim vector based on character frequencies.
EMBED_DIM = 32


async def _test_embedding_function(
    texts: list[str],
) -> list[list[float]]:
  """Return one deterministic, unit-length vector per input text.

  Each character of the lower-cased text increments the bucket
  ``ord(ch) % EMBED_DIM``; the bucket counts are then scaled to unit
  length. A text that produces an all-zero histogram (such as the
  empty string) comes back as the zero vector.
  """

  def _embed(text: str) -> list[float]:
    # Character-frequency histogram folded into EMBED_DIM buckets.
    buckets = [0.0] * EMBED_DIM
    for ch in text.lower():
      buckets[ord(ch) % EMBED_DIM] += 1.0
    # `or 1.0` keeps the division defined for an all-zero histogram.
    norm = math.sqrt(sum(b * b for b in buckets)) or 1.0
    return [b / norm for b in buckets]

  return [_embed(text) for text in texts]
str) -> Session: + """Create a test session with events.""" + return Session( + app_name=app_name, + user_id=user_id, + id=f"session-{uuid.uuid4().hex[:8]}", + last_update_time=1000, + events=[ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="I enjoy learning Python.")] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="model", + timestamp=12346, + content=types.Content( + parts=[ + types.Part( + text="Python is versatile and beginner-friendly." + ) + ] + ), + ), + Event( + id="event-3", + invocation_id="inv-3", + author="user", + timestamp=12347, + content=types.Content( + parts=[ + types.Part( + text="What about Rust for systems programming?" + ) + ] + ), + ), + ], + ) + + +@pytest.mark.asyncio +class TestValkeyMemoryServiceIntegration: + """Integration tests with vector similarity search.""" + + async def test_add_and_search_memories(self, memory_service): + """Test adding a session and searching with vector similarity.""" + session = _make_session("test-app", "user-1") + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Python programming language", + ) + + assert len(result.memories) >= 1 + texts = [m.content.parts[0].text for m in result.memories] + assert any("Python" in t for t in texts) + + async def test_search_returns_results_ranked_by_similarity( + self, memory_service + ): + """Test that results are returned by vector similarity.""" + session = _make_session("test-app", "user-1") + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + # Query similar to "Python" content should return Python memories + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="I enjoy learning Python programming", + ) + + assert len(result.memories) >= 1 + # The most 
similar result should be about Python + top_text = result.memories[0].content.parts[0].text + assert "Python" in top_text or "python" in top_text.lower() + + async def test_user_isolation(self, memory_service): + """Test that memories are isolated between users.""" + session1 = _make_session("test-app", "user-1") + session2 = Session( + app_name="test-app", + user_id="user-2", + id="session-other", + last_update_time=1000, + events=[ + Event( + id="event-other", + invocation_id="inv-other", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="I prefer Java over everything.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session1) + await memory_service.add_session_to_memory(session2) + await asyncio.sleep(0.5) + + # user-1 should not see user-2's Java memory + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Java programming", + ) + texts = [m.content.parts[0].text for m in result.memories] + assert not any("Java" in t for t in texts) + + # user-2 should see their own Java memory + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-2", + query="Java programming", + ) + assert len(result.memories) == 1 + assert "Java" in result.memories[0].content.parts[0].text + + async def test_app_isolation(self, memory_service): + """Test that memories are isolated between applications.""" + session1 = Session( + app_name="app-one", + user_id="user-1", + id="session-app1", + last_update_time=1000, + events=[ + Event( + id="event-app1", + invocation_id="inv-app1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Kubernetes orchestration tips.")] + ), + ), + ], + ) + session2 = Session( + app_name="app-two", + user_id="user-1", + id="session-app2", + last_update_time=1000, + events=[ + Event( + id="event-app2", + invocation_id="inv-app2", + author="user", + timestamp=12345, + content=types.Content( + 
parts=[types.Part(text="Docker container best practices.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session1) + await memory_service.add_session_to_memory(session2) + await asyncio.sleep(0.5) + + # app-one should only see its own memories + result = await memory_service.search_memory( + app_name="app-one", + user_id="user-1", + query="Kubernetes orchestration", + ) + assert len(result.memories) == 1 + + # app-two should only see its own Docker memory, not Kubernetes + result = await memory_service.search_memory( + app_name="app-two", + user_id="user-1", + query="Kubernetes orchestration", + ) + # KNN returns nearest neighbor from the filtered set (app-two only). + # The result should NOT contain app-one's Kubernetes content. + for mem in result.memories: + assert "Kubernetes" not in mem.content.parts[0].text + + async def test_multiple_sessions_accumulate(self, memory_service): + """Test that multiple sessions accumulate memories.""" + session1 = _make_session("test-app", "user-1") + session2 = Session( + app_name="test-app", + user_id="user-1", + id="session-2", + last_update_time=2000, + events=[ + Event( + id="event-extra", + invocation_id="inv-extra", + author="user", + timestamp=22345, + content=types.Content( + parts=[types.Part(text="Python web frameworks are useful.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session1) + await memory_service.add_session_to_memory(session2) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Python programming", + ) + + # Should find memories from both sessions + assert len(result.memories) >= 3 + + async def test_search_empty_store(self, memory_service): + """Test searching when no memories have been added.""" + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="anything at all", + ) + assert len(result.memories) == 0 + + async def 
test_similarity_top_k_limit(self, valkey_client): + """Test that similarity_top_k limits the number of results.""" + from glide import ft + + test_prefix = f"test:topk:{uuid.uuid4().hex[:8]}" + index_name = f"test_topk_idx_{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + index_name=index_name, + similarity_top_k=3, + embedding_dimensions=EMBED_DIM, + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, + ) + await service.create_index() + await asyncio.sleep(0.1) + + events = [ + Event( + id=f"event-{i}", + invocation_id=f"inv-{i}", + author="user", + timestamp=12345 + i, + content=types.Content( + parts=[types.Part(text=f"Python tip number {i} is great.")] + ), + ) + for i in range(6) + ] + session = Session( + app_name="test-app", + user_id="user-1", + id="session-topk", + last_update_time=1000, + events=events, + ) + + await service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + result = await service.search_memory( + app_name="test-app", + user_id="user-1", + query="Python tips", + ) + + assert len(result.memories) <= 3 + assert len(result.memories) >= 1 + + # Cleanup + try: + await ft.dropindex(valkey_client, index_name) + except Exception: + pass + try: + keys = await valkey_client.custom_command(["KEYS", f"{test_prefix}:*"]) + if keys: + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + await valkey_client.custom_command(["DEL", key_str]) + except Exception: + pass + + async def test_events_without_text_are_filtered(self, memory_service): + """Test that function_call and empty events are not stored.""" + session = Session( + app_name="test-app", + user_id="user-1", + id="session-filter", + last_update_time=1000, + events=[ + Event( + id="event-func", + invocation_id="inv-func", + author="agent", + timestamp=12345, + content=types.Content( + parts=[ + types.Part( + 
function_call=types.FunctionCall(name="search_tool") + ) + ] + ), + ), + Event( + id="event-empty", + invocation_id="inv-empty", + author="user", + timestamp=12346, + ), + Event( + id="event-text", + invocation_id="inv-text", + author="user", + timestamp=12347, + content=types.Content( + parts=[types.Part(text="This is valid content.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="valid content text", + ) + assert len(result.memories) == 1 + + async def test_memory_entry_metadata(self, memory_service): + """Test that returned MemoryEntry has correct metadata.""" + session = Session( + app_name="test-app", + user_id="user-1", + id="session-meta", + last_update_time=1000, + events=[ + Event( + id="event-meta", + invocation_id="inv-meta", + author="user", + timestamp=99999, + content=types.Content( + parts=[types.Part(text="Metadata verification test.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Metadata verification test", + ) + + assert len(result.memories) == 1 + entry = result.memories[0] + assert entry.content.parts[0].text == "Metadata verification test." 
+ assert entry.author == "user" + assert entry.timestamp == "99999" + + async def test_create_index_idempotent(self, valkey_client): + """Test that calling create_index multiple times is safe.""" + from glide import ft + + index_name = f"test_idem_idx_{uuid.uuid4().hex[:8]}" + test_prefix = f"test:idem:{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + index_name=index_name, + embedding_dimensions=EMBED_DIM, + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, + ) + + await service.create_index() + assert service._index_created is True + + # Second call should not raise + await service.create_index() + assert service._index_created is True + + # Cleanup + try: + await ft.dropindex(valkey_client, index_name) + except Exception: + pass + + async def test_add_events_to_memory(self, memory_service): + """Test incremental event ingestion via add_events_to_memory.""" + events = [ + Event( + id="event-inc-1", + invocation_id="inv-inc-1", + author="user", + timestamp=50001, + content=types.Content( + parts=[types.Part(text="Incremental memory about Golang.")] + ), + ), + Event( + id="event-inc-2", + invocation_id="inv-inc-2", + author="model", + timestamp=50002, + content=types.Content( + parts=[types.Part(text="Go is great for concurrency.")] + ), + ), + ] + + await memory_service.add_events_to_memory( + app_name="test-app", + user_id="user-1", + events=events, + session_id="session-incremental", + ) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Golang concurrency", + ) + assert len(result.memories) >= 1 + texts = [m.content.parts[0].text for m in result.memories] + assert any("Go" in t for t in texts) + + async def test_vector_distance_threshold(self, valkey_client): + """Test that vector_distance_threshold filters distant results.""" + from glide import ft + + test_prefix = 
f"test:thresh:{uuid.uuid4().hex[:8]}" + index_name = f"test_thresh_idx_{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + index_name=index_name, + similarity_top_k=10, + embedding_dimensions=EMBED_DIM, + vector_distance_threshold=0.01, # Very strict threshold + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, + ) + await service.create_index() + await asyncio.sleep(0.1) + + session = Session( + app_name="test-app", + user_id="user-1", + id="session-thresh", + last_update_time=1000, + events=[ + Event( + id="event-thresh", + invocation_id="inv-thresh", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Completely unrelated topic XYZ.")] + ), + ), + ], + ) + + await service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + # Search for something very different — should be filtered by threshold + result = await service.search_memory( + app_name="test-app", + user_id="user-1", + query="AAAAAAA BBBBBBB CCCCCCC", + ) + # With strict threshold, dissimilar results should be filtered + # (This depends on the embedding function producing distant vectors) + assert len(result.memories) <= 1 + + # Cleanup + try: + await ft.dropindex(valkey_client, index_name) + except Exception: + pass + try: + keys = await valkey_client.custom_command(["KEYS", f"{test_prefix}:*"]) + if keys: + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + await valkey_client.custom_command(["DEL", key_str]) + except Exception: + pass diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py new file mode 100644 index 00000000..92bfb4e2 --- /dev/null +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -0,0 +1,780 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import patch + +from google.adk.events.event import Event +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +from google.adk_community.memory.valkey_memory_service import ValkeyMemoryService +from google.adk_community.memory.valkey_memory_service import ValkeyMemoryServiceConfig + +MOCK_APP_NAME = "test-app" +MOCK_USER_ID = "test-user" +MOCK_SESSION_ID = "session-1" + + +async def _mock_embed_fn(texts: list[str]) -> list[list[float]]: + """Simple mock embedding function returning fixed-dim vectors.""" + return [[0.1] * 768 for _ in texts] + + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_ID, + last_update_time=1000, + events=[ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Hello, I like Python.")] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="model", + timestamp=12346, + content=types.Content( + parts=[ + types.Part(text="Python is a great programming language.") + ] + ), + ), + # Empty event, should be ignored + Event( + id="event-3", + invocation_id="inv-3", + author="user", + timestamp=12347, + ), + # Function call event, should be ignored + Event( + id="event-4", + invocation_id="inv-4", + author="agent", + timestamp=12348, + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(name="test_function") + ) + ] + ), + ), + ], +) + 
+MOCK_SESSION_WITH_EMPTY_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_ID, + last_update_time=1000, +) + + +@pytest.fixture +def mock_valkey_client(): + """Mock valkey-glide client for testing.""" + client = AsyncMock() + client.hset = AsyncMock(return_value=1) + client.expire = AsyncMock(return_value=True) + client.exec = AsyncMock(return_value=[]) + return client + + +@pytest.fixture +def memory_service(mock_valkey_client): + """Create ValkeyMemoryService instance for testing.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + service._index_created = True + return service + + +@pytest.fixture +def memory_service_with_config(mock_valkey_client): + """Create ValkeyMemoryService with custom config.""" + config = ValkeyMemoryServiceConfig( + similarity_top_k=5, + key_prefix="custom:mem", + index_name="custom_idx", + ttl_seconds=3600, + embedding_dimensions=768, + vector_distance_threshold=0.5, + ) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + config=config, + ) + service._index_created = True + return service + + +class TestValkeyMemoryServiceConfig: + """Tests for ValkeyMemoryServiceConfig.""" + + def test_default_config(self): + """Test default configuration values.""" + config = ValkeyMemoryServiceConfig() + assert config.similarity_top_k == 10 + assert config.key_prefix == "adk:memory" + assert config.index_name == "adk_memory_idx" + assert config.ttl_seconds is None + assert config.embedding_dimensions == 768 + assert config.distance_metric == "COSINE" + assert config.vector_distance_threshold is None + + def test_custom_config(self): + """Test custom configuration values.""" + config = ValkeyMemoryServiceConfig( + similarity_top_k=20, + key_prefix="my:prefix", + index_name="my_index", + ttl_seconds=7200, + embedding_dimensions=1536, + distance_metric="L2", + vector_distance_threshold=0.8, + ) + assert 
config.similarity_top_k == 20 + assert config.key_prefix == "my:prefix" + assert config.index_name == "my_index" + assert config.ttl_seconds == 7200 + assert config.embedding_dimensions == 1536 + assert config.distance_metric == "L2" + assert config.vector_distance_threshold == 0.8 + + def test_config_validation_top_k(self): + """Test similarity_top_k validation.""" + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(similarity_top_k=0) + + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(similarity_top_k=1001) + + def test_config_validation_distance_metric(self): + """Test distance_metric validation rejects invalid values.""" + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(distance_metric="HAMMING") + + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(distance_metric="invalid") + + def test_config_distance_metric_case_insensitive(self): + """Test that distance_metric is case-insensitive and normalized.""" + config = ValkeyMemoryServiceConfig(distance_metric="cosine") + assert config.distance_metric == "COSINE" + + config = ValkeyMemoryServiceConfig(distance_metric="l2") + assert config.distance_metric == "L2" + + config = ValkeyMemoryServiceConfig(distance_metric="ip") + assert config.distance_metric == "IP" + + +class TestValkeyMemoryServiceInit: + """Tests for ValkeyMemoryService initialization.""" + + def test_client_required(self): + """Test that client is required.""" + with pytest.raises(ValueError, match="client is required"): + ValkeyMemoryService(client=None, embedding_function=_mock_embed_fn) + + def test_embedding_function_required(self, mock_valkey_client): + """Test that embedding_function is required.""" + with pytest.raises(ValueError, match="embedding_function is required"): + ValkeyMemoryService(client=mock_valkey_client, embedding_function=None) + + def test_init_with_defaults(self, mock_valkey_client): + """Test initialization with default config.""" + service = ValkeyMemoryService( + 
client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + assert service._client is mock_valkey_client + assert service._config.similarity_top_k == 10 + assert service._index_created is False + + def test_init_with_config(self, mock_valkey_client): + """Test initialization with custom config.""" + config = ValkeyMemoryServiceConfig(similarity_top_k=5) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + config=config, + ) + assert service._config.similarity_top_k == 5 + + +class TestValkeyMemoryServiceCreateIndex: + """Tests for create_index.""" + + @pytest.mark.asyncio + async def test_create_index_success(self, mock_valkey_client): + """Test successful index creation.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.return_value = "OK" + await service.create_index() + + assert service._index_created is True + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_create_index_already_exists(self, mock_valkey_client): + """Test that existing index is handled gracefully.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.side_effect = Exception("Index already exists") + await service.create_index() + assert service._index_created is True + + @pytest.mark.asyncio + async def test_create_index_unexpected_error(self, mock_valkey_client): + """Test that unexpected errors are raised.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.side_effect = Exception("Connection refused") + with pytest.raises(Exception, match="Connection refused"): + await service.create_index() 
+ assert service._index_created is False + + +class TestValkeyMemoryServiceAddSession: + """Tests for add_session_to_memory.""" + + @pytest.mark.asyncio + async def test_add_session_success(self, memory_service, mock_valkey_client): + """Test successful addition of session memories.""" + with patch("glide.Batch") as MockBatch: + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: None + + await memory_service.add_session_to_memory(MOCK_SESSION) + + # Should call exec once with the batch + mock_valkey_client.exec.assert_called_once_with( + mock_batch_instance, raise_on_error=True + ) + # Batch should be created as non-atomic + MockBatch.assert_called_once_with(is_atomic=False) + + @pytest.mark.asyncio + async def test_add_session_filters_empty_events( + self, memory_service, mock_valkey_client + ): + """Test that events without text content are filtered out.""" + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) + mock_valkey_client.exec.assert_not_called() + + @pytest.mark.asyncio + async def test_add_session_with_ttl( + self, memory_service_with_config, mock_valkey_client + ): + """Test that TTL is set when configured.""" + with patch("glide.Batch") as MockBatch: + hset_calls = [] + expire_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: hset_calls.append(args) + mock_batch_instance.expire = lambda *args, **kwargs: expire_calls.append( + args + ) + + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + # 2 valid events -> 2 hset + 2 expire calls on the batch + assert len(hset_calls) == 2 + assert len(expire_calls) == 2 + # TTL should be 3600 + assert expire_calls[0][1] == 3600 + + @pytest.mark.asyncio + async def test_add_session_no_ttl_by_default( + self, memory_service, mock_valkey_client + ): + """Test that no TTL is set when not configured.""" + with 
patch("glide.Batch") as MockBatch: + expire_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: expire_calls.append( + args + ) + + await memory_service.add_session_to_memory(MOCK_SESSION) + + assert len(expire_calls) == 0 + + @pytest.mark.asyncio + async def test_add_session_embedding_error(self, mock_valkey_client): + """Test handling of embedding function failure.""" + + async def _failing_embed(texts): + raise RuntimeError("Embedding service unavailable") + + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_failing_embed, + ) + service._index_created = True + + # Should not raise, just log error + await service.add_session_to_memory(MOCK_SESSION) + mock_valkey_client.exec.assert_not_called() + + @pytest.mark.asyncio + async def test_add_session_batch_exec_error( + self, memory_service, mock_valkey_client + ): + """Test that batch exec failure raises RuntimeError.""" + mock_valkey_client.exec.side_effect = Exception("Connection error") + + with patch("glide.Batch") as MockBatch: + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: None + + with pytest.raises(RuntimeError, match="Memory ingestion failed"): + await memory_service.add_session_to_memory(MOCK_SESSION) + mock_valkey_client.exec.assert_called_once() + + @pytest.mark.asyncio + async def test_add_session_custom_key_prefix( + self, memory_service_with_config, mock_valkey_client + ): + """Test that custom key prefix is used.""" + with patch("glide.Batch") as MockBatch: + hset_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: hset_calls.append(args) + mock_batch_instance.expire = lambda *args, **kwargs: None + + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + key = 
hset_calls[0][0] + assert key.startswith("custom:mem:") + + @pytest.mark.asyncio + async def test_add_session_creates_index_if_needed(self, mock_valkey_client): + """Test that create_index is called if not yet created.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + + with ( + patch("glide.ft.create", new_callable=AsyncMock) as mock_create, + patch("glide.Batch") as MockBatch, + ): + mock_create.return_value = "OK" + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: None + + await service.add_session_to_memory(MOCK_SESSION) + mock_create.assert_called_once() + assert service._index_created is True + + +class TestValkeyMemoryServiceAddEvents: + """Tests for add_events_to_memory.""" + + @pytest.mark.asyncio + async def test_add_events_success(self, memory_service, mock_valkey_client): + """Test incremental event ingestion.""" + events = [ + Event( + id="ev-1", + invocation_id="inv-1", + author="user", + timestamp=100, + content=types.Content(parts=[types.Part(text="Hello world")]), + ), + ] + + with patch("glide.Batch") as MockBatch: + hset_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: hset_calls.append(args) + mock_batch_instance.expire = lambda *args, **kwargs: None + + await memory_service.add_events_to_memory( + app_name="myapp", + user_id="user1", + events=events, + session_id="sess-1", + ) + + assert len(hset_calls) == 1 + fields = hset_calls[0][1] + assert fields["content"] == "Hello world" + assert fields["app_name"] == "myapp" + assert fields["user_id"] == "user1" + assert fields["session_id"] == "sess-1" + mock_valkey_client.exec.assert_called_once() + + @pytest.mark.asyncio + async def test_add_events_filters_empty( + self, memory_service, mock_valkey_client + ): + """Test that empty events are skipped.""" + events = [ + 
Event(id="ev-empty", invocation_id="inv-1", author="user"), + ] + + await memory_service.add_events_to_memory( + app_name="myapp", + user_id="user1", + events=events, + ) + + mock_valkey_client.exec.assert_not_called() + + +class TestValkeyMemoryServiceSearch: + """Tests for search_memory.""" + + @pytest.mark.asyncio + async def test_search_memory_success( + self, memory_service, mock_valkey_client + ): + """Test successful memory search using KNN.""" + search_result = [ + 2, + { + b"adk:memory:abc123": { + b"content": b"I love Python programming", + b"author": b"user", + b"timestamp": b"12345", + b"__embedding_score": b"0.15", + }, + b"adk:memory:def456": { + b"content": b"Python has great libraries", + b"author": b"model", + b"timestamp": b"12346", + b"__embedding_score": b"0.25", + }, + }, + ] + + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = search_result + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python", + ) + + assert len(result.memories) == 2 + assert ( + result.memories[0].content.parts[0].text + == "I love Python programming" + ) + assert result.memories[0].author == "user" + + @pytest.mark.asyncio + async def test_search_memory_empty_result( + self, memory_service, mock_valkey_client + ): + """Test search when no results.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="anything", + ) + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_uses_knn_query( + self, memory_service, mock_valkey_client + ): + """Test that search builds a KNN query with TAG filters.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] + + await memory_service.search_memory( + app_name="my-app", + 
user_id="user-123", + query="test query", + ) + + call_args = mock_search.call_args + query_str = call_args[0][2] + assert "my\\-app" in query_str + assert "user\\-123" in query_str + assert "KNN" in query_str + assert "@embedding" in query_str + + @pytest.mark.asyncio + async def test_search_memory_passes_query_vec( + self, memory_service, mock_valkey_client + ): + """Test that query embedding is passed as params.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] + + await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + + call_args = mock_search.call_args + options = call_args[0][3] + assert "query_vec" in options.params + assert isinstance(options.params["query_vec"], bytes) + + @pytest.mark.asyncio + async def test_search_memory_distance_threshold( + self, memory_service_with_config, mock_valkey_client + ): + """Test that distance threshold filters results.""" + search_result = [ + 2, + { + b"adk:memory:close": { + b"content": b"Close match", + b"author": b"user", + b"timestamp": b"1", + b"__embedding_score": b"0.3", + }, + b"adk:memory:far": { + b"content": b"Far match", + b"author": b"user", + b"timestamp": b"2", + b"__embedding_score": b"0.9", + }, + }, + ] + + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = search_result + + # Threshold is 0.5, so only "Close match" (0.3) should pass + result = await memory_service_with_config.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == "Close match" + + @pytest.mark.asyncio + async def test_search_memory_embedding_error(self, mock_valkey_client): + """Test graceful handling of embedding failure during search.""" + + async def _failing_embed(texts): + raise RuntimeError("Embedding service down") + + service = ValkeyMemoryService( + 
client=mock_valkey_client, + embedding_function=_failing_embed, + ) + service._index_created = True + + result = await service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_ft_search_error( + self, memory_service, mock_valkey_client + ): + """Test graceful handling of FT.SEARCH failure.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.side_effect = Exception("Connection error") + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + assert len(result.memories) == 0 + + +class TestValkeyMemoryServiceBuildQuery: + """Tests for _build_knn_query.""" + + def test_basic_knn_query(self, memory_service): + """Test KNN query construction.""" + query = memory_service._build_knn_query("myapp", "user1", 10) + assert "@app_name:{myapp}" in query + assert "@user_id:{user1}" in query + assert "KNN 10 @embedding" in query + + def test_hyphenated_values(self, memory_service): + """Test escaping of hyphens in TAG values.""" + query = memory_service._build_knn_query("my-app", "user-1", 5) + assert "my\\-app" in query + assert "user\\-1" in query + + def test_special_characters_dots(self, memory_service): + """Test escaping of dots in TAG values.""" + query = memory_service._build_knn_query("com.example.app", "user.1", 10) + assert "com\\.example\\.app" in query + assert "user\\.1" in query + + def test_special_characters_colons(self, memory_service): + """Test escaping of colons in TAG values.""" + query = memory_service._build_knn_query("app:v2", "user:123", 10) + assert "app\\:v2" in query + assert "user\\:123" in query + + def test_special_characters_at_sign(self, memory_service): + """Test escaping of @ sign in TAG values.""" + query = memory_service._build_knn_query("app@org", "user@domain", 10) + assert "app\\@org" in query + assert 
"user\\@domain" in query + + def test_special_characters_spaces(self, memory_service): + """Test escaping of spaces in TAG values.""" + query = memory_service._build_knn_query("my app", "user 1", 10) + assert "my\\ app" in query + assert "user\\ 1" in query + + +class TestValkeyMemoryServiceClose: + """Tests for close method.""" + + @pytest.mark.asyncio + async def test_close_does_not_close_client( + self, memory_service, mock_valkey_client + ): + """Test that close does not close the underlying client.""" + await memory_service.close() + mock_valkey_client.close.assert_not_called() + + +class TestValkeyMemoryServiceEnsureIndex: + """Tests for _ensure_index with asyncio.Lock.""" + + @pytest.mark.asyncio + async def test_ensure_index_skips_if_already_created( + self, mock_valkey_client + ): + """Test that _ensure_index does not call create_index if already done.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + service._index_created = True + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + await service._ensure_index() + mock_create.assert_not_called() + + @pytest.mark.asyncio + async def test_ensure_index_calls_create_index_once(self, mock_valkey_client): + """Test that concurrent calls to _ensure_index only create index once.""" + import asyncio + + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + + call_count = 0 + + async def mock_create(*args, **kwargs): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) # Simulate async work + return "OK" + + with patch("glide.ft.create", side_effect=mock_create): + # Launch multiple concurrent ensure_index calls + await asyncio.gather( + service._ensure_index(), + service._ensure_index(), + service._ensure_index(), + ) + + # Should only call create once due to lock + assert call_count == 1 + assert service._index_created is True + + +class 
TestValkeyMemoryServiceEscapeTagValue: + """Tests for _escape_tag_value static method.""" + + def test_no_special_characters(self): + """Test that plain values are unchanged.""" + assert ValkeyMemoryService._escape_tag_value("myapp") == "myapp" + + def test_hyphen_escaped(self): + """Test hyphen escaping.""" + assert ValkeyMemoryService._escape_tag_value("my-app") == "my\\-app" + + def test_dot_escaped(self): + """Test dot escaping.""" + assert ( + ValkeyMemoryService._escape_tag_value("com.example.app") + == "com\\.example\\.app" + ) + + def test_colon_escaped(self): + """Test colon escaping.""" + assert ValkeyMemoryService._escape_tag_value("app:v2") == "app\\:v2" + + def test_at_sign_escaped(self): + """Test @ sign escaping.""" + assert ( + ValkeyMemoryService._escape_tag_value("user@domain") == "user\\@domain" + ) + + def test_space_escaped(self): + """Test space escaping.""" + assert ValkeyMemoryService._escape_tag_value("my app") == "my\\ app" + + def test_multiple_special_chars(self): + """Test multiple special characters in one value.""" + result = ValkeyMemoryService._escape_tag_value("a-b.c:d@e") + assert result == "a\\-b\\.c\\:d\\@e" + + def test_question_mark_escaped(self): + """Test question mark escaping (wildcard glob in TAG queries).""" + assert ValkeyMemoryService._escape_tag_value("app?") == "app\\?" + assert ValkeyMemoryService._escape_tag_value("a?b?c") == "a\\?b\\?c"