From 1a150d98c452e32d187759b61831b4696ab895ef Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Tue, 2 Jun 2026 11:13:41 -0700 Subject: [PATCH 1/7] feat: add ValkeyMemoryService using valkey-glide 2.4.0 Implements BaseMemoryService backed by Valkey using the valkey-glide client library. Stores memories as JSON in Valkey lists keyed by app_name and user_id, with simple text-based substring search. - ValkeyMemoryServiceConfig: configurable search_top_k, key_prefix, ttl - ValkeyMemoryService: add_session_to_memory, search_memory, close - Optional dependency: valkey-glide>=2.4.0 under [valkey] extra - 23 unit tests (mocked client) - 5 integration tests (requires running Valkey instance) Ref: AEA-497 Signed-off-by: Daria Korenieva --- pyproject.toml | 3 + src/google/adk_community/memory/__init__.py | 4 + .../memory/valkey_memory_service.py | 227 ++++++++ .../test_valkey_memory_service_integration.py | 252 +++++++++ .../memory/test_valkey_memory_service.py | 496 ++++++++++++++++++ 5 files changed, 982 insertions(+) create mode 100644 src/google/adk_community/memory/valkey_memory_service.py create mode 100644 tests/integration/test_valkey_memory_service_integration.py create mode 100644 tests/unittests/memory/test_valkey_memory_service.py diff --git a/pyproject.toml b/pyproject.toml index 3bf11124..738478ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ sdc-agents = [ "sdc-agents>=4.3.3; python_version >= '3.11'", ] spraay = ["web3>=6.0.0"] +valkey = [ + "valkey-glide>=2.4.0", +] [tool.pyink] diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c0..f029ec77 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -16,9 +16,13 @@ from .open_memory_service import OpenMemoryService from .open_memory_service import OpenMemoryServiceConfig +from .valkey_memory_service import ValkeyMemoryService +from .valkey_memory_service import ValkeyMemoryServiceConfig __all__ = [ "OpenMemoryService", "OpenMemoryServiceConfig", + "ValkeyMemoryService", + "ValkeyMemoryServiceConfig", ] diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py new file mode 100644 index 00000000..c7e18560 --- /dev/null +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -0,0 +1,227 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Valkey-backed memory service for ADK using valkey-glide client.""" + +from __future__ import annotations + +import json +import logging +import time +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override + +from google.adk.memory.base_memory_service import BaseMemoryService +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry + +from .utils import extract_text_from_event + +if TYPE_CHECKING: + from google.adk.sessions.session import Session + +logger = logging.getLogger('google_adk.' + __name__) + + +class ValkeyMemoryServiceConfig(BaseModel): + """Configuration for ValkeyMemoryService. + + Attributes: + search_top_k: Maximum number of memories to retrieve per search. + key_prefix: Prefix for all Valkey keys to avoid collisions. + ttl_seconds: Optional TTL for memory entries in seconds. None means + no expiration. + """ + + search_top_k: int = Field(default=10, ge=1, le=100) + key_prefix: str = Field(default="adk:memory") + ttl_seconds: Optional[int] = Field(default=None, ge=1) + + +class ValkeyMemoryService(BaseMemoryService): + """Memory service implementation using Valkey as the backend. + + Uses valkey-glide client for communication with Valkey server. + Memories are stored as JSON strings in Valkey lists, indexed by + app_name and user_id for efficient retrieval. + + Example usage: + + from glide import GlideClientConfiguration, NodeAddress, GlideClient + + config = GlideClientConfiguration( + addresses=[NodeAddress(host="localhost", port=6379)], + client_name="adk_memory_client", + ) + client = await GlideClient.create(config) + service = ValkeyMemoryService(client=client) + + """ + + def __init__( + self, + client, + config: Optional[ValkeyMemoryServiceConfig] = None, + ): + """Initializes the Valkey memory service. + + Args: + client: A connected valkey-glide GlideClient or + GlideClusterClient instance. The caller is responsible + for creating and managing the client lifecycle. + config: Optional ValkeyMemoryServiceConfig instance. + If None, uses defaults. + """ + if client is None: + raise ValueError( + "client is required. Provide a connected valkey-glide " + "GlideClient or GlideClusterClient instance." + ) + self._client = client + self._config = config or ValkeyMemoryServiceConfig() + + def _memory_list_key(self, app_name: str, user_id: str) -> str: + """Generate the Valkey key for a user's memory list.""" + return f"{self._config.key_prefix}:{app_name}:{user_id}:entries" + + def _serialize_memory( + self, event, content_text: str, session + ) -> str: + """Serialize an event into a JSON string for storage.""" + memory_data = { + "content": content_text, + "author": event.author, + "timestamp": event.timestamp, + "session_id": session.id, + "event_id": event.id, + "app_name": session.app_name, + "user_id": session.user_id, + "created_at": time.time(), + } + return json.dumps(memory_data) + + @override + async def add_session_to_memory(self, session: Session): + """Add a session's events to Valkey memory storage.""" + memories_added = 0 + list_key = self._memory_list_key(session.app_name, session.user_id) + + for event in session.events: + content_text = extract_text_from_event(event) + if not content_text: + continue + + try: + serialized = self._serialize_memory(event, content_text, session) + await self._client.rpush(list_key, [serialized]) + memories_added += 1 + logger.debug("Added memory for event %s", event.id) + except Exception as e: + logger.error( + "Failed to add memory for event %s: %s", event.id, e + ) + + if self._config.ttl_seconds and memories_added > 0: + try: + await self._client.expire(list_key, self._config.ttl_seconds) + except Exception as e: + logger.error("Failed to set TTL on key %s: %s", list_key, e) + + logger.info( + "Added %d memories from session %s", memories_added, session.id + ) + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Search for memories matching the query. + + Performs a simple text-based search over stored memories for + the given app and user. Retrieves all stored memories and + filters them by checking if the query terms appear in the + content. + + Args: + app_name: The application name to scope the search. + user_id: The user ID to scope the search. + query: The search query string. + + Returns: + SearchMemoryResponse containing matching MemoryEntry objects. + """ + list_key = self._memory_list_key(app_name, user_id) + + try: + # Retrieve all memories for this user/app + raw_memories = await self._client.lrange(list_key, 0, -1) + + if not raw_memories: + return SearchMemoryResponse(memories=[]) + + memories = [] + query_lower = query.lower() + query_terms = query_lower.split() + + for raw in raw_memories: + try: + raw_str = ( + raw.decode("utf-8") if isinstance(raw, bytes) else raw + ) + memory_data = json.loads(raw_str) + content_text = memory_data.get("content", "") + + # Simple term-matching search + content_lower = content_text.lower() + if any(term in content_lower for term in query_terms): + content = types.Content( + parts=[types.Part(text=content_text)] + ) + timestamp = memory_data.get("timestamp") + if timestamp is not None: + timestamp = str(timestamp) + entry = MemoryEntry( + content=content, + author=memory_data.get("author"), + timestamp=timestamp, + ) + memories.append(entry) + + if len(memories) >= self._config.search_top_k: + break + except (json.JSONDecodeError, KeyError) as e: + logger.debug("Failed to parse memory entry: %s", e) + continue + + logger.info( + "Found %d memories for query: '%s'", len(memories), query + ) + return SearchMemoryResponse(memories=memories) + + except Exception as e: + logger.error("Failed to search memories: %s", e) + return SearchMemoryResponse(memories=[]) + + async def close(self): + """Close the memory service. + + Note: This does NOT close the underlying Valkey client, as + the client lifecycle is managed by the caller. + """ + pass diff --git a/tests/integration/test_valkey_memory_service_integration.py b/tests/integration/test_valkey_memory_service_integration.py new file mode 100644 index 00000000..3f13e68f --- /dev/null +++ b/tests/integration/test_valkey_memory_service_integration.py @@ -0,0 +1,252 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for ValkeyMemoryService. + +Requires a running Valkey instance. Set VALKEY_HOST and VALKEY_PORT +environment variables if not using defaults (localhost:6379). + +Run with: + pytest tests/integration/test_valkey_memory_service_integration.py -v +""" + +from __future__ import annotations + +import os +import uuid + +from google.adk.events.event import Event +from google.adk.sessions.session import Session +from google.adk_community.memory.valkey_memory_service import ( + ValkeyMemoryService, + ValkeyMemoryServiceConfig, +) +from google.genai import types +import pytest + +VALKEY_HOST = os.environ.get("VALKEY_HOST", "localhost") +VALKEY_PORT = int(os.environ.get("VALKEY_PORT", "6379")) + + +def _requires_valkey(): + """Check if Valkey is available, skip if not.""" + try: + import glide # noqa: F401 + except ImportError: + pytest.skip("valkey-glide not installed") + + +@pytest.fixture +async def valkey_client(): + """Create a connected valkey-glide client.""" + _requires_valkey() + from glide import GlideClient, GlideClientConfiguration, NodeAddress + + config = GlideClientConfiguration( + addresses=[NodeAddress(host=VALKEY_HOST, port=VALKEY_PORT)], + client_name="adk_memory_integration_test_client", + ) + client = await GlideClient.create(config) + yield client + await client.close() + + +@pytest.fixture +async def memory_service(valkey_client): + """Create ValkeyMemoryService with a unique prefix for test isolation.""" + test_prefix = f"test:memory:{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + search_top_k=10, + ) + service = ValkeyMemoryService(client=valkey_client, config=config) + yield service + + # Cleanup: delete test keys + list_key = f"{test_prefix}:*" + try: + # Use KEYS to find all test keys and delete them + keys = await valkey_client.custom_command(["KEYS", list_key]) + if keys: + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + await valkey_client.custom_command(["DEL", key_str]) + except Exception: + pass + + +def _make_session(app_name: str, user_id: str) -> Session: + """Create a test session with events.""" + return Session( + app_name=app_name, + user_id=user_id, + id=f"session-{uuid.uuid4().hex[:8]}", + last_update_time=1000, + events=[ + Event( + id='event-1', + invocation_id='inv-1', + author='user', + timestamp=12345, + content=types.Content( + parts=[types.Part(text='I enjoy learning Python.')] + ), + ), + Event( + id='event-2', + invocation_id='inv-2', + author='model', + timestamp=12346, + content=types.Content( + parts=[ + types.Part( + text='Python is versatile and beginner-friendly.' + ) + ] + ), + ), + Event( + id='event-3', + invocation_id='inv-3', + author='user', + timestamp=12347, + content=types.Content( + parts=[ + types.Part( + text='What about Rust for systems programming?' + ) + ] + ), + ), + ], + ) + + +@pytest.mark.asyncio +class TestValkeyMemoryServiceIntegration: + """Integration tests for ValkeyMemoryService with a real Valkey instance.""" + + async def test_add_and_search_memories(self, memory_service): + """Test adding a session and searching for memories.""" + session = _make_session("test-app", "user-1") + + await memory_service.add_session_to_memory(session) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Python", + ) + + assert len(result.memories) >= 1 + texts = [m.content.parts[0].text for m in result.memories] + assert any("Python" in t for t in texts) + + async def test_search_returns_empty_for_no_match(self, memory_service): + """Test that search returns empty when no memories match.""" + session = _make_session("test-app", "user-1") + await memory_service.add_session_to_memory(session) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="JavaScript framework", + ) + + assert len(result.memories) == 0 + + async def test_user_isolation(self, memory_service): + """Test that memories are isolated between users.""" + session1 = _make_session("test-app", "user-1") + session2 = Session( + app_name="test-app", + user_id="user-2", + id="session-other", + last_update_time=1000, + events=[ + Event( + id='event-other', + invocation_id='inv-other', + author='user', + timestamp=12345, + content=types.Content( + parts=[types.Part(text='I prefer Java over everything.')] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session1) + await memory_service.add_session_to_memory(session2) + + # user-1 should not see user-2's memories + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Java", + ) + assert len(result.memories) == 0 + + # user-2 should see their own Java memory + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-2", + query="Java", + ) + assert len(result.memories) == 1 + + async def test_multiple_sessions_accumulate(self, memory_service): + """Test that multiple sessions accumulate memories.""" + session1 = _make_session("test-app", "user-1") + session2 = Session( + app_name="test-app", + user_id="user-1", + id="session-2", + last_update_time=2000, + events=[ + Event( + id='event-extra', + invocation_id='inv-extra', + author='user', + timestamp=22345, + content=types.Content( + parts=[ + types.Part(text='Python web frameworks are useful.') + ] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session1) + await memory_service.add_session_to_memory(session2) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Python", + ) + + # Should find memories from both sessions + assert len(result.memories) >= 3 + + async def test_search_empty_store(self, memory_service): + """Test searching when no memories have been added.""" + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="anything", + ) + + assert len(result.memories) == 0 diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py new file mode 100644 index 00000000..f7c464d4 --- /dev/null +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -0,0 +1,496 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import AsyncMock, MagicMock + +from google.adk.events.event import Event +from google.adk.sessions.session import Session +from google.adk_community.memory.valkey_memory_service import ( + ValkeyMemoryService, + ValkeyMemoryServiceConfig, +) +from google.genai import types +import pytest + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' +MOCK_SESSION_ID = 'session-1' + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_ID, + last_update_time=1000, + events=[ + Event( + id='event-1', + invocation_id='inv-1', + author='user', + timestamp=12345, + content=types.Content( + parts=[types.Part(text='Hello, I like Python.')] + ), + ), + Event( + id='event-2', + invocation_id='inv-2', + author='model', + timestamp=12346, + content=types.Content( + parts=[ + types.Part(text='Python is a great programming language.') + ] + ), + ), + # Empty event, should be ignored + Event( + id='event-3', + invocation_id='inv-3', + author='user', + timestamp=12347, + ), + # Function call event, should be ignored + Event( + id='event-4', + invocation_id='inv-4', + author='agent', + timestamp=12348, + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(name='test_function') + ) + ] + ), + ), + ], +) + +MOCK_SESSION_WITH_EMPTY_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_ID, + last_update_time=1000, +) + + +@pytest.fixture +def mock_valkey_client(): + """Mock valkey-glide client for testing.""" + client = AsyncMock() + client.rpush = AsyncMock(return_value=1) + client.lrange = AsyncMock(return_value=[]) + client.expire = AsyncMock(return_value=True) + return client + + +@pytest.fixture +def memory_service(mock_valkey_client): + """Create ValkeyMemoryService instance for testing.""" + return ValkeyMemoryService(client=mock_valkey_client) + + +@pytest.fixture +def memory_service_with_config(mock_valkey_client): + """Create ValkeyMemoryService with custom config.""" + config = ValkeyMemoryServiceConfig( + search_top_k=5, + key_prefix="custom:mem", + ttl_seconds=3600, + ) + return ValkeyMemoryService(client=mock_valkey_client, config=config) + + +class TestValkeyMemoryServiceConfig: + """Tests for ValkeyMemoryServiceConfig.""" + + def test_default_config(self): + """Test default configuration values.""" + config = ValkeyMemoryServiceConfig() + assert config.search_top_k == 10 + assert config.key_prefix == "adk:memory" + assert config.ttl_seconds is None + + def test_custom_config(self): + """Test custom configuration values.""" + config = ValkeyMemoryServiceConfig( + search_top_k=20, + key_prefix="my:prefix", + ttl_seconds=7200, + ) + assert config.search_top_k == 20 + assert config.key_prefix == "my:prefix" + assert config.ttl_seconds == 7200 + + def test_config_validation_search_top_k(self): + """Test search_top_k validation.""" + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(search_top_k=0) + + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(search_top_k=101) + + +class TestValkeyMemoryServiceInit: + """Tests for ValkeyMemoryService initialization.""" + + def test_client_required(self): + """Test that client is required.""" + with pytest.raises(ValueError, match="client is required"): + ValkeyMemoryService(client=None) + + def test_init_with_client(self, mock_valkey_client): + """Test initialization with a valid client.""" + service = ValkeyMemoryService(client=mock_valkey_client) + assert service._client is mock_valkey_client + assert service._config.search_top_k == 10 + + def test_init_with_config(self, mock_valkey_client): + """Test initialization with custom config.""" + config = ValkeyMemoryServiceConfig(search_top_k=5) + service = ValkeyMemoryService( + client=mock_valkey_client, config=config + ) + assert service._config.search_top_k == 5 + + +class TestValkeyMemoryServiceAddSession: + """Tests for add_session_to_memory.""" + + @pytest.mark.asyncio + async def test_add_session_success( + self, memory_service, mock_valkey_client + ): + """Test successful addition of session memories.""" + await memory_service.add_session_to_memory(MOCK_SESSION) + + # Should make 2 rpush calls (one per valid event with text) + assert mock_valkey_client.rpush.call_count == 2 + + # Check first call + first_call = mock_valkey_client.rpush.call_args_list[0] + key = first_call[0][0] + assert key == "adk:memory:test-app:test-user:entries" + + value = first_call[0][1][0] + data = json.loads(value) + assert data["content"] == "Hello, I like Python." + assert data["author"] == "user" + assert data["session_id"] == MOCK_SESSION_ID + assert data["event_id"] == "event-1" + + # Check second call + second_call = mock_valkey_client.rpush.call_args_list[1] + value = second_call[0][1][0] + data = json.loads(value) + assert data["content"] == "Python is a great programming language." + assert data["author"] == "model" + + @pytest.mark.asyncio + async def test_add_session_filters_empty_events( + self, memory_service, mock_valkey_client + ): + """Test that events without text content are filtered out.""" + await memory_service.add_session_to_memory( + MOCK_SESSION_WITH_EMPTY_EVENTS + ) + assert mock_valkey_client.rpush.call_count == 0 + + @pytest.mark.asyncio + async def test_add_session_with_ttl( + self, memory_service_with_config, mock_valkey_client + ): + """Test that TTL is set when configured.""" + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + mock_valkey_client.expire.assert_called_once_with( + "custom:mem:test-app:test-user:entries", 3600 + ) + + @pytest.mark.asyncio + async def test_add_session_no_ttl_by_default( + self, memory_service, mock_valkey_client + ): + """Test that no TTL is set when not configured.""" + await memory_service.add_session_to_memory(MOCK_SESSION) + mock_valkey_client.expire.assert_not_called() + + @pytest.mark.asyncio + async def test_add_session_error_handling( + self, memory_service, mock_valkey_client + ): + """Test error handling during memory addition.""" + mock_valkey_client.rpush.side_effect = Exception("Connection error") + + # Should not raise exception, just log error + await memory_service.add_session_to_memory(MOCK_SESSION) + assert mock_valkey_client.rpush.call_count == 2 + + @pytest.mark.asyncio + async def test_add_session_custom_key_prefix( + self, memory_service_with_config, mock_valkey_client + ): + """Test that custom key prefix is used.""" + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + first_call = mock_valkey_client.rpush.call_args_list[0] + key = first_call[0][0] + assert key == "custom:mem:test-app:test-user:entries" + + +class TestValkeyMemoryServiceSearch: + """Tests for search_memory.""" + + @pytest.mark.asyncio + async def test_search_memory_success( + self, memory_service, mock_valkey_client + ): + """Test successful memory search.""" + stored_memories = [ + json.dumps({ + "content": "I love Python programming", + "author": "user", + "timestamp": 12345, + }).encode(), + json.dumps({ + "content": "Java is also popular", + "author": "model", + "timestamp": 12346, + }).encode(), + json.dumps({ + "content": "Python has great libraries", + "author": "user", + "timestamp": 12347, + }).encode(), + ] + mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python", + ) + + assert len(result.memories) == 2 + assert result.memories[0].content.parts[0].text == ( + "I love Python programming" + ) + assert result.memories[0].author == "user" + assert result.memories[1].content.parts[0].text == ( + "Python has great libraries" + ) + + @pytest.mark.asyncio + async def test_search_memory_no_results( + self, memory_service, mock_valkey_client + ): + """Test search with no matching memories.""" + stored_memories = [ + json.dumps({ + "content": "Hello world", + "author": "user", + "timestamp": 12345, + }).encode(), + ] + mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Rust language", + ) + + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_empty_store( + self, memory_service, mock_valkey_client + ): + """Test search when no memories are stored.""" + mock_valkey_client.lrange = AsyncMock(return_value=[]) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="anything", + ) + + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_none_response( + self, memory_service, mock_valkey_client + ): + """Test search when lrange returns None.""" + mock_valkey_client.lrange = AsyncMock(return_value=None) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="anything", + ) + + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_respects_top_k( + self, memory_service_with_config, mock_valkey_client + ): + """Test that search respects search_top_k config.""" + # Create more memories than top_k (5) + stored_memories = [ + json.dumps({ + "content": f"Python tip number {i}", + "author": "user", + "timestamp": 12345 + i, + }).encode() + for i in range(10) + ] + mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + + result = await memory_service_with_config.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python", + ) + + # Should return at most 5 (search_top_k) + assert len(result.memories) == 5 + + @pytest.mark.asyncio + async def test_search_memory_case_insensitive( + self, memory_service, mock_valkey_client + ): + """Test that search is case-insensitive.""" + stored_memories = [ + json.dumps({ + "content": "PYTHON is great", + "author": "user", + "timestamp": 12345, + }).encode(), + ] + mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="python", + ) + + assert len(result.memories) == 1 + + @pytest.mark.asyncio + async def test_search_memory_error_handling( + self, memory_service, mock_valkey_client + ): + """Test graceful error handling during search.""" + mock_valkey_client.lrange.side_effect = Exception("Connection error") + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_handles_corrupt_entries( + self, memory_service, mock_valkey_client + ): + """Test that corrupt entries are skipped gracefully.""" + stored_memories = [ + b"not valid json", + json.dumps({ + "content": "Valid Python memory", + "author": "user", + "timestamp": 12345, + }).encode(), + ] + mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python", + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == ( + "Valid Python memory" + ) + + @pytest.mark.asyncio + async def test_search_memory_multi_term_query( + self, memory_service, mock_valkey_client + ): + """Test search with multiple terms (any term matches).""" + stored_memories = [ + json.dumps({ + "content": "I love Python", + "author": "user", + "timestamp": 12345, + }).encode(), + json.dumps({ + "content": "Java is enterprise", + "author": "model", + "timestamp": 12346, + }).encode(), + json.dumps({ + "content": "Rust is fast", + "author": "user", + "timestamp": 12347, + }).encode(), + ] + mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python Java", + ) + + # Both "Python" and "Java" memories should match + assert len(result.memories) == 2 + + @pytest.mark.asyncio + async def test_search_memory_correct_key( + self, memory_service, mock_valkey_client + ): + """Test that the correct Valkey key is queried.""" + mock_valkey_client.lrange = AsyncMock(return_value=[]) + + await memory_service.search_memory( + app_name="my-app", + user_id="user-123", + query="test", + ) + + mock_valkey_client.lrange.assert_called_once_with( + "adk:memory:my-app:user-123:entries", 0, -1 + ) + + +class TestValkeyMemoryServiceClose: + """Tests for close method.""" + + @pytest.mark.asyncio + async def test_close_does_not_close_client( + self, memory_service, mock_valkey_client + ): + """Test that close does not close the underlying client.""" + await memory_service.close() + # Client's close should NOT be called + mock_valkey_client.close.assert_not_called() From 2d69fba28fffc3ad5833dde3af665d3cc411da9c Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Tue, 2 Jun 2026 11:36:28 -0700 Subject: [PATCH 2/7] feat: rewrite ValkeyMemoryService to use Valkey Search module Replaces the simple substring-matching implementation with full-text search powered by the Valkey Search module (FT.CREATE / FT.SEARCH). Changes: - Memories stored as Valkey Hash keys (indexed automatically) - FT.CREATE with TEXT field for content, TAG fields for app_name/user_id - FT.SEARCH for full-text search with TAG filtering - Expanded integration tests (11 tests covering isolation, top_k, etc.) - Added memory module README with usage documentation Ref: AEA-497 Signed-off-by: Daria Korenieva --- src/google/adk_community/memory/README.md | 91 ++++ .../memory/valkey_memory_service.py | 284 +++++++--- .../test_valkey_memory_service_integration.py | 366 +++++++++++-- .../memory/test_valkey_memory_service.py | 495 +++++++++--------- 4 files changed, 877 insertions(+), 359 deletions(-) create mode 100644 src/google/adk_community/memory/README.md diff --git a/src/google/adk_community/memory/README.md b/src/google/adk_community/memory/README.md new file mode 100644 index 00000000..7471eab4 --- /dev/null +++ b/src/google/adk_community/memory/README.md @@ -0,0 +1,91 @@ +# Memory Services + +Community-contributed memory service implementations for the +[Google ADK](https://google.github.io/adk-docs/) framework. + +## Available Services + +### ValkeyMemoryService + +A memory service backed by [Valkey](https://valkey.io/) using the +[Valkey Search module](https://valkey.io/topics/search/) for full-text +search. Uses the [valkey-glide](https://github.com/valkey-io/valkey-glide) +client library. + +**Features:** +- Full-text search powered by the Valkey Search module (FT.CREATE / FT.SEARCH) +- Memories stored as Valkey Hash keys with automatic indexing +- TAG-based filtering by `app_name` and `user_id` for scoped queries +- Configurable TTL for automatic memory expiration +- Case-insensitive search out of the box + +**Requirements:** +- Valkey server with the Search module loaded (e.g., + [valkey-bundle](https://hub.docker.com/r/valkey/valkey-bundle) image) +- `valkey-glide >= 2.4.0` + +**Installation:** + +```bash +pip install google-adk-community[valkey] +``` + +**Usage:** + +```python +from glide import GlideClient, GlideClientConfiguration, NodeAddress +from google.adk_community.memory import ValkeyMemoryService, ValkeyMemoryServiceConfig + +# Create a valkey-glide client +config = GlideClientConfiguration( + addresses=[NodeAddress(host="localhost", port=6379)], + client_name="my_adk_app", +) +client = await GlideClient.create(config) + +# Create the memory service +memory_config = ValkeyMemoryServiceConfig( + search_top_k=10, # Max results per search + key_prefix="adk:memory", # Valkey key prefix + index_name="adk_memory_idx", # Search index name + ttl_seconds=None, # Optional TTL (None = no expiry) +) +memory_service = ValkeyMemoryService(client=client, config=memory_config) + +# The index is created automatically on first use, or explicitly: +await memory_service.create_index() + +# Use with ADK runner +from google.adk.runners import Runner + +runner = Runner( + agent=my_agent, + memory_service=memory_service, + ... +) +``` + +**Running Valkey with Search module:** + +```bash +# Using podman +podman run -d --name valkey -p 6379:6379 valkey/valkey-bundle:9.1 + +# Using docker +docker run -d --name valkey -p 6379:6379 valkey/valkey-bundle:9.1 +``` + +--- + +### OpenMemoryService + +A memory service backed by [OpenMemory](https://openmemory.cavira.app/). +Uses HTTP API calls for memory storage and retrieval. + +**Installation:** + +```bash +pip install google-adk-community +``` + +See the `OpenMemoryService` class documentation for usage details. diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py index c7e18560..84b98f7f 100644 --- a/src/google/adk_community/memory/valkey_memory_service.py +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -12,31 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Valkey-backed memory service for ADK using valkey-glide client.""" +"""Valkey-backed memory service for ADK using valkey-glide client. + +Uses the Valkey Search module (FT.CREATE / FT.SEARCH) for full-text +search over stored memories. +""" from __future__ import annotations -import json import logging import time from typing import Optional from typing import TYPE_CHECKING +import uuid +from google.adk.memory.base_memory_service import BaseMemoryService +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry from google.genai import types from pydantic import BaseModel from pydantic import Field from typing_extensions import override -from google.adk.memory.base_memory_service import BaseMemoryService -from google.adk.memory.base_memory_service import SearchMemoryResponse -from google.adk.memory.memory_entry import MemoryEntry - from .utils import extract_text_from_event if TYPE_CHECKING: from google.adk.sessions.session import Session -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) class ValkeyMemoryServiceConfig(BaseModel): @@ -45,21 +48,28 @@ class ValkeyMemoryServiceConfig(BaseModel): Attributes: search_top_k: Maximum number of memories to retrieve per search. key_prefix: Prefix for all Valkey keys to avoid collisions. + index_name: Name of the Valkey Search index. ttl_seconds: Optional TTL for memory entries in seconds. None means no expiration. """ search_top_k: int = Field(default=10, ge=1, le=100) key_prefix: str = Field(default="adk:memory") + index_name: str = Field(default="adk_memory_idx") ttl_seconds: Optional[int] = Field(default=None, ge=1) class ValkeyMemoryService(BaseMemoryService): - """Memory service implementation using Valkey as the backend. + """Memory service implementation using Valkey with the Search module. + + Uses valkey-glide client for communication with Valkey server and the + Valkey Search module (FT.CREATE / FT.SEARCH) for full-text search + over stored memories. - Uses valkey-glide client for communication with Valkey server. - Memories are stored as JSON strings in Valkey lists, indexed by - app_name and user_id for efficient retrieval. + Memories are stored as Valkey Hash keys with fields: content, author, + timestamp, session_id, event_id, app_name, user_id, created_at. A + full-text search index is created over the content field, with TAG + fields for app_name and user_id to enable scoped queries. Example usage: @@ -71,6 +81,7 @@ class ValkeyMemoryService(BaseMemoryService): ) client = await GlideClient.create(config) service = ValkeyMemoryService(client=client) + await service.create_index() """ @@ -95,68 +106,148 @@ def __init__( ) self._client = client self._config = config or ValkeyMemoryServiceConfig() + self._index_created = False - def _memory_list_key(self, app_name: str, user_id: str) -> str: - """Generate the Valkey key for a user's memory list.""" - return f"{self._config.key_prefix}:{app_name}:{user_id}:entries" - - def _serialize_memory( - self, event, content_text: str, session - ) -> str: - """Serialize an event into a JSON string for storage.""" - memory_data = { - "content": content_text, - "author": event.author, - "timestamp": event.timestamp, - "session_id": session.id, - "event_id": event.id, - "app_name": session.app_name, - "user_id": session.user_id, - "created_at": time.time(), - } - return json.dumps(memory_data) + async def create_index(self): + """Create the Valkey Search index if it does not already exist. + + Creates a full-text search index with: + - content: TEXT field for full-text search + - app_name: TAG field for filtering by application + - user_id: TAG field for filtering by user + - author: TAG field for filtering by author + - timestamp: NUMERIC field for sorting + + This method is idempotent — if the index already exists, it + will log a debug message and return without error. + """ + from glide import DataType + from glide import ft + from glide import FtCreateOptions + from glide import NumericField + from glide import TagField + from glide import TextField + + schema = [ + TextField("content"), + TagField("app_name"), + TagField("user_id"), + TagField("author"), + NumericField("timestamp", sortable=True), + ] + + options = FtCreateOptions( + data_type=DataType.HASH, + prefixes=[f"{self._config.key_prefix}:"], + ) + + try: + await ft.create( + self._client, + self._config.index_name, + schema, + options, + ) + self._index_created = True + logger.info("Created search index: %s", self._config.index_name) + except Exception as e: + error_msg = str(e).lower() + if "index already exists" in error_msg or "exists" in error_msg: + self._index_created = True + logger.debug( + "Search index already exists: %s", + self._config.index_name, + ) + else: + raise + + def _memory_hash_key(self) -> str: + """Generate a unique Valkey hash key for a memory entry.""" + unique_id = uuid.uuid4().hex[:12] + return f"{self._config.key_prefix}:{unique_id}" @override async def add_session_to_memory(self, session: Session): - """Add a session's events to Valkey memory storage.""" + """Add a session's events to Valkey memory storage. + + Each event with text content is stored as a separate Valkey Hash + key with the configured prefix, making it automatically indexed + by the search module. + """ + if not self._index_created: + await self.create_index() + memories_added = 0 - list_key = self._memory_list_key(session.app_name, session.user_id) for event in session.events: content_text = extract_text_from_event(event) if not content_text: continue + hash_key = self._memory_hash_key() + field_values = { + "content": content_text, + "author": event.author or "", + "timestamp": str(event.timestamp) if event.timestamp else "0", + "session_id": session.id, + "event_id": event.id or "", + "app_name": session.app_name, + "user_id": session.user_id, + "created_at": str(time.time()), + } + try: - serialized = self._serialize_memory(event, content_text, session) - await self._client.rpush(list_key, [serialized]) + await self._client.hset(hash_key, field_values) memories_added += 1 - logger.debug("Added memory for event %s", event.id) - except Exception as e: - logger.error( - "Failed to add memory for event %s: %s", event.id, e - ) + logger.debug("Added memory for event %s at key %s", event.id, hash_key) - if self._config.ttl_seconds and memories_added > 0: - try: - await self._client.expire(list_key, self._config.ttl_seconds) + if self._config.ttl_seconds: + await self._client.expire(hash_key, self._config.ttl_seconds) except Exception as e: - logger.error("Failed to set TTL on key %s: %s", list_key, e) + logger.error("Failed to add memory for event %s: %s", event.id, e) - logger.info( - "Added %d memories from session %s", memories_added, session.id - ) + logger.info("Added %d memories from session %s", memories_added, session.id) + + def _build_search_query(self, app_name: str, user_id: str, query: str) -> str: + """Build an FT.SEARCH query string with filters. + + Constructs a query that: + - Filters by app_name and user_id using TAG filters + - Searches content using full-text search + + Args: + app_name: Application name filter. + user_id: User ID filter. + query: The user's search query text. + + Returns: + A Valkey Search query string. + """ + # Escape special characters in TAG values + escaped_app = app_name.replace("-", "\\-") + escaped_user = user_id.replace("-", "\\-") + + # Build full-text query with TAG filters + # Use @field:{value} for TAG filtering and plain text for content + tag_filter = f"@app_name:{{{escaped_app}}} @user_id:{{{escaped_user}}}" + + # Escape special FT.SEARCH characters in the query text + search_chars = r'@!{}()|-=><~*:;$["\]^' + escaped_query = query + for ch in search_chars: + escaped_query = escaped_query.replace(ch, f"\\{ch}") + + return f"{tag_filter} {escaped_query}" @override async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: - """Search for memories matching the query. + """Search for memories matching the query using Valkey Search. - Performs a simple text-based search over stored memories for - the given app and user. Retrieves all stored memories and - filters them by checking if the query terms appear in the - content. + Uses FT.SEARCH with the Valkey Search module for full-text search. + Results are filtered by app_name and user_id, and the query is + matched against the content field. Args: app_name: The application name to scope the search. @@ -166,58 +257,79 @@ async def search_memory( Returns: SearchMemoryResponse containing matching MemoryEntry objects. """ - list_key = self._memory_list_key(app_name, user_id) + from glide import ft + from glide import FtSearchLimit + from glide import FtSearchOptions + + if not self._index_created: + await self.create_index() try: - # Retrieve all memories for this user/app - raw_memories = await self._client.lrange(list_key, 0, -1) + search_query = self._build_search_query(app_name, user_id, query) + options = FtSearchOptions( + limit=FtSearchLimit(0, self._config.search_top_k), + ) - if not raw_memories: + result = await ft.search( + self._client, + self._config.index_name, + search_query, + options, + ) + + if not result or len(result) < 2: + return SearchMemoryResponse(memories=[]) + + # result is [count, {doc_id: {field: value, ...}, ...}] + doc_count = result[0] + if doc_count == 0: return SearchMemoryResponse(memories=[]) memories = [] - query_lower = query.lower() - query_terms = query_lower.split() + doc_map = result[1] if len(result) > 1 else {} - for raw in raw_memories: + for doc_id, fields in doc_map.items(): try: - raw_str = ( - raw.decode("utf-8") if isinstance(raw, bytes) else raw + content_text = self._decode(fields.get(b"content", b"")) + if not content_text: + continue + + author = self._decode(fields.get(b"author", b"")) or None + timestamp_raw = self._decode(fields.get(b"timestamp", b"0")) + # Numeric fields may return "12345.0"; normalize to int string + if timestamp_raw and timestamp_raw != "0": + try: + timestamp = str(int(float(timestamp_raw))) + except (ValueError, TypeError): + timestamp = timestamp_raw + else: + timestamp = None + + content = types.Content(parts=[types.Part(text=content_text)]) + entry = MemoryEntry( + content=content, + author=author, + timestamp=timestamp, ) - memory_data = json.loads(raw_str) - content_text = memory_data.get("content", "") - - # Simple term-matching search - content_lower = content_text.lower() - if any(term in content_lower for term in query_terms): - content = types.Content( - parts=[types.Part(text=content_text)] - ) - timestamp = memory_data.get("timestamp") - if timestamp is not None: - timestamp = str(timestamp) - entry = MemoryEntry( - content=content, - author=memory_data.get("author"), - timestamp=timestamp, - ) - memories.append(entry) - - if len(memories) >= self._config.search_top_k: - break - except (json.JSONDecodeError, KeyError) as e: - logger.debug("Failed to parse memory entry: %s", e) + memories.append(entry) + except Exception as e: + logger.debug("Failed to parse search result: %s", e) continue - logger.info( - "Found %d memories for query: '%s'", len(memories), query - ) + logger.info("Found %d memories for query: '%s'", len(memories), query) return SearchMemoryResponse(memories=memories) except Exception as e: logger.error("Failed to search memories: %s", e) return SearchMemoryResponse(memories=[]) + @staticmethod + def _decode(value) -> str: + """Decode bytes to string if needed.""" + if isinstance(value, bytes): + return value.decode("utf-8") + return str(value) if value is not None else "" + async def close(self): """Close the memory service. diff --git a/tests/integration/test_valkey_memory_service_integration.py b/tests/integration/test_valkey_memory_service_integration.py index 3f13e68f..b1e899c1 100644 --- a/tests/integration/test_valkey_memory_service_integration.py +++ b/tests/integration/test_valkey_memory_service_integration.py @@ -14,27 +14,31 @@ """Integration tests for ValkeyMemoryService. -Requires a running Valkey instance. Set VALKEY_HOST and VALKEY_PORT -environment variables if not using defaults (localhost:6379). +Requires a running Valkey instance with the Search module loaded. +Set VALKEY_HOST and VALKEY_PORT environment variables if not using +defaults (localhost:6379). -Run with: +The valkey-bundle image (valkey/valkey-bundle) includes the Search +module. Run with: + + podman run -d --name valkey-test -p 6379:6379 valkey/valkey-bundle:9.1 pytest tests/integration/test_valkey_memory_service_integration.py -v """ from __future__ import annotations +import asyncio import os import uuid from google.adk.events.event import Event from google.adk.sessions.session import Session -from google.adk_community.memory.valkey_memory_service import ( - ValkeyMemoryService, - ValkeyMemoryServiceConfig, -) from google.genai import types import pytest +from google.adk_community.memory.valkey_memory_service import ValkeyMemoryService +from google.adk_community.memory.valkey_memory_service import ValkeyMemoryServiceConfig + VALKEY_HOST = os.environ.get("VALKEY_HOST", "localhost") VALKEY_PORT = int(os.environ.get("VALKEY_PORT", "6379")) @@ -51,7 +55,9 @@ def _requires_valkey(): async def valkey_client(): """Create a connected valkey-glide client.""" _requires_valkey() - from glide import GlideClient, GlideClientConfiguration, NodeAddress + from glide import GlideClient + from glide import GlideClientConfiguration + from glide import NodeAddress config = GlideClientConfiguration( addresses=[NodeAddress(host=VALKEY_HOST, port=VALKEY_PORT)], @@ -65,19 +71,30 @@ async def valkey_client(): @pytest.fixture async def memory_service(valkey_client): """Create ValkeyMemoryService with a unique prefix for test isolation.""" + from glide import ft + test_prefix = f"test:memory:{uuid.uuid4().hex[:8]}" + index_name = f"test_idx_{uuid.uuid4().hex[:8]}" config = ValkeyMemoryServiceConfig( key_prefix=test_prefix, + index_name=index_name, search_top_k=10, ) service = ValkeyMemoryService(client=valkey_client, config=config) + await service.create_index() + + # Small delay for index to be ready + await asyncio.sleep(0.1) + yield service - # Cleanup: delete test keys - list_key = f"{test_prefix}:*" + # Cleanup: drop the index and delete test keys try: - # Use KEYS to find all test keys and delete them - keys = await valkey_client.custom_command(["KEYS", list_key]) + await ft.dropindex(valkey_client, index_name) + except Exception: + pass + try: + keys = await valkey_client.custom_command(["KEYS", f"{test_prefix}:*"]) if keys: for key in keys: key_str = key.decode() if isinstance(key, bytes) else key @@ -95,36 +112,36 @@ def _make_session(app_name: str, user_id: str) -> Session: last_update_time=1000, events=[ Event( - id='event-1', - invocation_id='inv-1', - author='user', + id="event-1", + invocation_id="inv-1", + author="user", timestamp=12345, content=types.Content( - parts=[types.Part(text='I enjoy learning Python.')] + parts=[types.Part(text="I enjoy learning Python.")] ), ), Event( - id='event-2', - invocation_id='inv-2', - author='model', + id="event-2", + invocation_id="inv-2", + author="model", timestamp=12346, content=types.Content( parts=[ types.Part( - text='Python is versatile and beginner-friendly.' + text="Python is versatile and beginner-friendly." ) ] ), ), Event( - id='event-3', - invocation_id='inv-3', - author='user', + id="event-3", + invocation_id="inv-3", + author="user", timestamp=12347, content=types.Content( parts=[ types.Part( - text='What about Rust for systems programming?' + text="What about Rust for systems programming?" ) ] ), @@ -142,6 +159,7 @@ async def test_add_and_search_memories(self, memory_service): session = _make_session("test-app", "user-1") await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) # Wait for indexing result = await memory_service.search_memory( app_name="test-app", @@ -157,11 +175,12 @@ async def test_search_returns_empty_for_no_match(self, memory_service): """Test that search returns empty when no memories match.""" session = _make_session("test-app", "user-1") await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="JavaScript framework", + query="JavaScript framework Angular", ) assert len(result.memories) == 0 @@ -176,12 +195,12 @@ async def test_user_isolation(self, memory_service): last_update_time=1000, events=[ Event( - id='event-other', - invocation_id='inv-other', - author='user', + id="event-other", + invocation_id="inv-other", + author="user", timestamp=12345, content=types.Content( - parts=[types.Part(text='I prefer Java over everything.')] + parts=[types.Part(text="I prefer Java over everything.")] ), ), ], @@ -189,6 +208,7 @@ async def test_user_isolation(self, memory_service): await memory_service.add_session_to_memory(session1) await memory_service.add_session_to_memory(session2) + await asyncio.sleep(0.5) # user-1 should not see user-2's memories result = await memory_service.search_memory( @@ -206,6 +226,63 @@ async def test_user_isolation(self, memory_service): ) assert len(result.memories) == 1 + async def test_app_isolation(self, memory_service): + """Test that memories are isolated between applications.""" + session1 = Session( + app_name="app-one", + user_id="user-1", + id="session-app1", + last_update_time=1000, + events=[ + Event( + id="event-app1", + invocation_id="inv-app1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Kubernetes orchestration tips.")] + ), + ), + ], + ) + session2 = Session( + app_name="app-two", + user_id="user-1", + id="session-app2", + last_update_time=1000, + events=[ + Event( + id="event-app2", + invocation_id="inv-app2", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Docker container best practices.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session1) + await memory_service.add_session_to_memory(session2) + await asyncio.sleep(0.5) + + # app-one should only see its own memories + result = await memory_service.search_memory( + app_name="app-one", + user_id="user-1", + query="Kubernetes", + ) + assert len(result.memories) == 1 + + # app-two should not see app-one's memories + result = await memory_service.search_memory( + app_name="app-two", + user_id="user-1", + query="Kubernetes", + ) + assert len(result.memories) == 0 + async def test_multiple_sessions_accumulate(self, memory_service): """Test that multiple sessions accumulate memories.""" session1 = _make_session("test-app", "user-1") @@ -216,14 +293,12 @@ async def test_multiple_sessions_accumulate(self, memory_service): last_update_time=2000, events=[ Event( - id='event-extra', - invocation_id='inv-extra', - author='user', + id="event-extra", + invocation_id="inv-extra", + author="user", timestamp=22345, content=types.Content( - parts=[ - types.Part(text='Python web frameworks are useful.') - ] + parts=[types.Part(text="Python web frameworks are useful.")] ), ), ], @@ -231,6 +306,7 @@ async def test_multiple_sessions_accumulate(self, memory_service): await memory_service.add_session_to_memory(session1) await memory_service.add_session_to_memory(session2) + await asyncio.sleep(0.5) result = await memory_service.search_memory( app_name="test-app", @@ -250,3 +326,223 @@ async def test_search_empty_store(self, memory_service): ) assert len(result.memories) == 0 + + async def test_search_case_insensitive(self, memory_service): + """Test that full-text search is case-insensitive.""" + session = Session( + app_name="test-app", + user_id="user-1", + id="session-case", + last_update_time=1000, + events=[ + Event( + id="event-case", + invocation_id="inv-case", + author="user", + timestamp=12345, + content=types.Content( + parts=[ + types.Part( + text="VALKEY is a high performance datastore." + ) + ] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + # Search with lowercase should find uppercase content + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="valkey", + ) + assert len(result.memories) == 1 + + async def test_search_top_k_limit(self, valkey_client): + """Test that search_top_k limits the number of results.""" + from glide import ft + + test_prefix = f"test:topk:{uuid.uuid4().hex[:8]}" + index_name = f"test_topk_idx_{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + index_name=index_name, + search_top_k=3, + ) + service = ValkeyMemoryService(client=valkey_client, config=config) + await service.create_index() + await asyncio.sleep(0.1) + + # Add more events than top_k + events = [ + Event( + id=f"event-{i}", + invocation_id=f"inv-{i}", + author="user", + timestamp=12345 + i, + content=types.Content( + parts=[types.Part(text=f"Python tip number {i} is great.")] + ), + ) + for i in range(6) + ] + session = Session( + app_name="test-app", + user_id="user-1", + id="session-topk", + last_update_time=1000, + events=events, + ) + + await service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + result = await service.search_memory( + app_name="test-app", + user_id="user-1", + query="Python", + ) + + # Should return at most 3 (search_top_k) + assert len(result.memories) <= 3 + assert len(result.memories) >= 1 + + # Cleanup + try: + await ft.dropindex(valkey_client, index_name) + except Exception: + pass + try: + keys = await valkey_client.custom_command(["KEYS", f"{test_prefix}:*"]) + if keys: + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + await valkey_client.custom_command(["DEL", key_str]) + except Exception: + pass + + async def test_events_without_text_are_filtered(self, memory_service): + """Test that function_call and empty events are not stored.""" + session = Session( + app_name="test-app", + user_id="user-1", + id="session-filter", + last_update_time=1000, + events=[ + # Function call event - should be filtered + Event( + id="event-func", + invocation_id="inv-func", + author="agent", + timestamp=12345, + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(name="search_tool") + ) + ] + ), + ), + # Empty event - should be filtered + Event( + id="event-empty", + invocation_id="inv-empty", + author="user", + timestamp=12346, + ), + # Valid text event - should be stored + Event( + id="event-text", + invocation_id="inv-text", + author="user", + timestamp=12347, + content=types.Content( + parts=[types.Part(text="This is valid content.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + # Only the text event should be searchable + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="valid content", + ) + assert len(result.memories) == 1 + + # Function call content should not appear + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="search_tool", + ) + assert len(result.memories) == 0 + + async def test_memory_entry_metadata(self, memory_service): + """Test that returned MemoryEntry has correct metadata.""" + session = Session( + app_name="test-app", + user_id="user-1", + id="session-meta", + last_update_time=1000, + events=[ + Event( + id="event-meta", + invocation_id="inv-meta", + author="user", + timestamp=99999, + content=types.Content( + parts=[types.Part(text="Metadata verification test.")] + ), + ), + ], + ) + + await memory_service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Metadata verification", + ) + + assert len(result.memories) == 1 + entry = result.memories[0] + assert entry.content.parts[0].text == "Metadata verification test." + assert entry.author == "user" + assert entry.timestamp == "99999" + + async def test_create_index_idempotent(self, valkey_client): + """Test that calling create_index multiple times is safe.""" + from glide import ft + + index_name = f"test_idem_idx_{uuid.uuid4().hex[:8]}" + test_prefix = f"test:idem:{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + index_name=index_name, + ) + service = ValkeyMemoryService(client=valkey_client, config=config) + + # First call should succeed + await service.create_index() + assert service._index_created is True + + # Second call should not raise + await service.create_index() + assert service._index_created is True + + # Cleanup + try: + await ft.dropindex(valkey_client, index_name) + except Exception: + pass diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py index f7c464d4..602b3ea5 100644 --- a/tests/unittests/memory/test_valkey_memory_service.py +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -12,21 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock +from unittest.mock import patch from google.adk.events.event import Event from google.adk.sessions.session import Session -from google.adk_community.memory.valkey_memory_service import ( - ValkeyMemoryService, - ValkeyMemoryServiceConfig, -) from google.genai import types import pytest -MOCK_APP_NAME = 'test-app' -MOCK_USER_ID = 'test-user' -MOCK_SESSION_ID = 'session-1' +from google.adk_community.memory.valkey_memory_service import ValkeyMemoryService +from google.adk_community.memory.valkey_memory_service import ValkeyMemoryServiceConfig + +MOCK_APP_NAME = "test-app" +MOCK_USER_ID = "test-user" +MOCK_SESSION_ID = "session-1" MOCK_SESSION = Session( app_name=MOCK_APP_NAME, @@ -35,42 +34,42 @@ last_update_time=1000, events=[ Event( - id='event-1', - invocation_id='inv-1', - author='user', + id="event-1", + invocation_id="inv-1", + author="user", timestamp=12345, content=types.Content( - parts=[types.Part(text='Hello, I like Python.')] + parts=[types.Part(text="Hello, I like Python.")] ), ), Event( - id='event-2', - invocation_id='inv-2', - author='model', + id="event-2", + invocation_id="inv-2", + author="model", timestamp=12346, content=types.Content( parts=[ - types.Part(text='Python is a great programming language.') + types.Part(text="Python is a great programming language.") ] ), ), # Empty event, should be ignored Event( - id='event-3', - invocation_id='inv-3', - author='user', + id="event-3", + invocation_id="inv-3", + author="user", timestamp=12347, ), # Function call event, should be ignored Event( - id='event-4', - invocation_id='inv-4', - author='agent', + id="event-4", + invocation_id="inv-4", + author="agent", timestamp=12348, content=types.Content( parts=[ types.Part( - function_call=types.FunctionCall(name='test_function') + function_call=types.FunctionCall(name="test_function") ) ] ), @@ -90,8 +89,7 @@ def mock_valkey_client(): """Mock valkey-glide client for testing.""" client = AsyncMock() - client.rpush = AsyncMock(return_value=1) - client.lrange = AsyncMock(return_value=[]) + client.hset = AsyncMock(return_value=1) client.expire = AsyncMock(return_value=True) return client @@ -99,7 +97,9 @@ def mock_valkey_client(): @pytest.fixture def memory_service(mock_valkey_client): """Create ValkeyMemoryService instance for testing.""" - return ValkeyMemoryService(client=mock_valkey_client) + service = ValkeyMemoryService(client=mock_valkey_client) + service._index_created = True # Skip index creation in unit tests + return service @pytest.fixture @@ -108,9 +108,12 @@ def memory_service_with_config(mock_valkey_client): config = ValkeyMemoryServiceConfig( search_top_k=5, key_prefix="custom:mem", + index_name="custom_idx", ttl_seconds=3600, ) - return ValkeyMemoryService(client=mock_valkey_client, config=config) + service = ValkeyMemoryService(client=mock_valkey_client, config=config) + service._index_created = True + return service class TestValkeyMemoryServiceConfig: @@ -121,6 +124,7 @@ def test_default_config(self): config = ValkeyMemoryServiceConfig() assert config.search_top_k == 10 assert config.key_prefix == "adk:memory" + assert config.index_name == "adk_memory_idx" assert config.ttl_seconds is None def test_custom_config(self): @@ -128,10 +132,12 @@ def test_custom_config(self): config = ValkeyMemoryServiceConfig( search_top_k=20, key_prefix="my:prefix", + index_name="my_index", ttl_seconds=7200, ) assert config.search_top_k == 20 assert config.key_prefix == "my:prefix" + assert config.index_name == "my_index" assert config.ttl_seconds == 7200 def test_config_validation_search_top_k(self): @@ -156,57 +162,84 @@ def test_init_with_client(self, mock_valkey_client): service = ValkeyMemoryService(client=mock_valkey_client) assert service._client is mock_valkey_client assert service._config.search_top_k == 10 + assert service._index_created is False def test_init_with_config(self, mock_valkey_client): """Test initialization with custom config.""" config = ValkeyMemoryServiceConfig(search_top_k=5) - service = ValkeyMemoryService( - client=mock_valkey_client, config=config - ) + service = ValkeyMemoryService(client=mock_valkey_client, config=config) assert service._config.search_top_k == 5 +class TestValkeyMemoryServiceCreateIndex: + """Tests for create_index.""" + + @pytest.mark.asyncio + async def test_create_index_success(self, mock_valkey_client): + """Test successful index creation.""" + service = ValkeyMemoryService(client=mock_valkey_client) + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.return_value = "OK" + await service.create_index() + + assert service._index_created is True + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_create_index_already_exists(self, mock_valkey_client): + """Test that existing index is handled gracefully.""" + service = ValkeyMemoryService(client=mock_valkey_client) + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.side_effect = Exception("Index already exists") + await service.create_index() + + assert service._index_created is True + + @pytest.mark.asyncio + async def test_create_index_unexpected_error(self, mock_valkey_client): + """Test that unexpected errors are raised.""" + service = ValkeyMemoryService(client=mock_valkey_client) + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.side_effect = Exception("Connection refused") + with pytest.raises(Exception, match="Connection refused"): + await service.create_index() + + assert service._index_created is False + + class TestValkeyMemoryServiceAddSession: """Tests for add_session_to_memory.""" @pytest.mark.asyncio - async def test_add_session_success( - self, memory_service, mock_valkey_client - ): + async def test_add_session_success(self, memory_service, mock_valkey_client): """Test successful addition of session memories.""" await memory_service.add_session_to_memory(MOCK_SESSION) - # Should make 2 rpush calls (one per valid event with text) - assert mock_valkey_client.rpush.call_count == 2 + # Should make 2 hset calls (one per valid event with text) + assert mock_valkey_client.hset.call_count == 2 - # Check first call - first_call = mock_valkey_client.rpush.call_args_list[0] + # Check first call stores correct fields + first_call = mock_valkey_client.hset.call_args_list[0] key = first_call[0][0] - assert key == "adk:memory:test-app:test-user:entries" - - value = first_call[0][1][0] - data = json.loads(value) - assert data["content"] == "Hello, I like Python." - assert data["author"] == "user" - assert data["session_id"] == MOCK_SESSION_ID - assert data["event_id"] == "event-1" - - # Check second call - second_call = mock_valkey_client.rpush.call_args_list[1] - value = second_call[0][1][0] - data = json.loads(value) - assert data["content"] == "Python is a great programming language." - assert data["author"] == "model" + assert key.startswith("adk:memory:") + fields = first_call[0][1] + assert fields["content"] == "Hello, I like Python." + assert fields["author"] == "user" + assert fields["app_name"] == MOCK_APP_NAME + assert fields["user_id"] == MOCK_USER_ID + assert fields["session_id"] == MOCK_SESSION_ID + assert fields["event_id"] == "event-1" @pytest.mark.asyncio async def test_add_session_filters_empty_events( self, memory_service, mock_valkey_client ): """Test that events without text content are filtered out.""" - await memory_service.add_session_to_memory( - MOCK_SESSION_WITH_EMPTY_EVENTS - ) - assert mock_valkey_client.rpush.call_count == 0 + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) + assert mock_valkey_client.hset.call_count == 0 @pytest.mark.asyncio async def test_add_session_with_ttl( @@ -215,9 +248,10 @@ async def test_add_session_with_ttl( """Test that TTL is set when configured.""" await memory_service_with_config.add_session_to_memory(MOCK_SESSION) - mock_valkey_client.expire.assert_called_once_with( - "custom:mem:test-app:test-user:entries", 3600 - ) + # Should set TTL on each hash key + assert mock_valkey_client.expire.call_count == 2 + expire_call = mock_valkey_client.expire.call_args_list[0] + assert expire_call[0][1] == 3600 @pytest.mark.asyncio async def test_add_session_no_ttl_by_default( @@ -232,11 +266,11 @@ async def test_add_session_error_handling( self, memory_service, mock_valkey_client ): """Test error handling during memory addition.""" - mock_valkey_client.rpush.side_effect = Exception("Connection error") + mock_valkey_client.hset.side_effect = Exception("Connection error") # Should not raise exception, just log error await memory_service.add_session_to_memory(MOCK_SESSION) - assert mock_valkey_client.rpush.call_count == 2 + assert mock_valkey_client.hset.call_count == 2 @pytest.mark.asyncio async def test_add_session_custom_key_prefix( @@ -245,9 +279,22 @@ async def test_add_session_custom_key_prefix( """Test that custom key prefix is used.""" await memory_service_with_config.add_session_to_memory(MOCK_SESSION) - first_call = mock_valkey_client.rpush.call_args_list[0] + first_call = mock_valkey_client.hset.call_args_list[0] key = first_call[0][0] - assert key == "custom:mem:test-app:test-user:entries" + assert key.startswith("custom:mem:") + + @pytest.mark.asyncio + async def test_add_session_creates_index_if_needed(self, mock_valkey_client): + """Test that create_index is called if not yet created.""" + service = ValkeyMemoryService(client=mock_valkey_client) + assert service._index_created is False + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + mock_create.return_value = "OK" + await service.add_session_to_memory(MOCK_SESSION) + + mock_create.assert_called_once() + assert service._index_created is True class TestValkeyMemoryServiceSearch: @@ -257,230 +304,203 @@ class TestValkeyMemoryServiceSearch: async def test_search_memory_success( self, memory_service, mock_valkey_client ): - """Test successful memory search.""" - stored_memories = [ - json.dumps({ - "content": "I love Python programming", - "author": "user", - "timestamp": 12345, - }).encode(), - json.dumps({ - "content": "Java is also popular", - "author": "model", - "timestamp": 12346, - }).encode(), - json.dumps({ - "content": "Python has great libraries", - "author": "user", - "timestamp": 12347, - }).encode(), + """Test successful memory search using FT.SEARCH.""" + search_result = [ + 2, + { + b"adk:memory:abc123": { + b"content": b"I love Python programming", + b"author": b"user", + b"timestamp": b"12345", + }, + b"adk:memory:def456": { + b"content": b"Python has great libraries", + b"author": b"model", + b"timestamp": b"12346", + }, + }, ] - mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="Python", - ) + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = search_result - assert len(result.memories) == 2 - assert result.memories[0].content.parts[0].text == ( - "I love Python programming" - ) - assert result.memories[0].author == "user" - assert result.memories[1].content.parts[0].text == ( - "Python has great libraries" - ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python", + ) + + assert len(result.memories) == 2 + assert ( + result.memories[0].content.parts[0].text + == "I love Python programming" + ) + assert result.memories[0].author == "user" + assert result.memories[0].timestamp == "12345" @pytest.mark.asyncio async def test_search_memory_no_results( self, memory_service, mock_valkey_client ): """Test search with no matching memories.""" - stored_memories = [ - json.dumps({ - "content": "Hello world", - "author": "user", - "timestamp": 12345, - }).encode(), - ] - mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="Rust language", - ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Rust language", + ) - assert len(result.memories) == 0 + assert len(result.memories) == 0 @pytest.mark.asyncio - async def test_search_memory_empty_store( + async def test_search_memory_empty_result( self, memory_service, mock_valkey_client ): - """Test search when no memories are stored.""" - mock_valkey_client.lrange = AsyncMock(return_value=[]) + """Test search when FT.SEARCH returns empty.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="anything", - ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="anything", + ) - assert len(result.memories) == 0 + assert len(result.memories) == 0 @pytest.mark.asyncio - async def test_search_memory_none_response( + async def test_search_memory_uses_correct_query( self, memory_service, mock_valkey_client ): - """Test search when lrange returns None.""" - mock_valkey_client.lrange = AsyncMock(return_value=None) - - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="anything", - ) - - assert len(result.memories) == 0 + """Test that search builds correct FT.SEARCH query.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] + + await memory_service.search_memory( + app_name="my-app", + user_id="user-123", + query="test query", + ) + + call_args = mock_search.call_args + query_str = call_args[0][2] + assert "my\\-app" in query_str + assert "user\\-123" in query_str + assert "test query" in query_str @pytest.mark.asyncio async def test_search_memory_respects_top_k( self, memory_service_with_config, mock_valkey_client ): - """Test that search respects search_top_k config.""" - # Create more memories than top_k (5) - stored_memories = [ - json.dumps({ - "content": f"Python tip number {i}", - "author": "user", - "timestamp": 12345 + i, - }).encode() - for i in range(10) - ] - mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) - - result = await memory_service_with_config.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="Python", - ) - - # Should return at most 5 (search_top_k) - assert len(result.memories) == 5 - - @pytest.mark.asyncio - async def test_search_memory_case_insensitive( - self, memory_service, mock_valkey_client - ): - """Test that search is case-insensitive.""" - stored_memories = [ - json.dumps({ - "content": "PYTHON is great", - "author": "user", - "timestamp": 12345, - }).encode(), - ] - mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + """Test that search uses search_top_k in LIMIT.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = [0, {}] - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="python", - ) + await memory_service_with_config.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="Python", + ) - assert len(result.memories) == 1 + call_args = mock_search.call_args + options = call_args[0][3] + assert options.limit is not None @pytest.mark.asyncio async def test_search_memory_error_handling( self, memory_service, mock_valkey_client ): """Test graceful error handling during search.""" - mock_valkey_client.lrange.side_effect = Exception("Connection error") + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.side_effect = Exception("Connection error") - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="test", - ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) - assert len(result.memories) == 0 + assert len(result.memories) == 0 @pytest.mark.asyncio - async def test_search_memory_handles_corrupt_entries( + async def test_search_memory_handles_missing_fields( self, memory_service, mock_valkey_client ): - """Test that corrupt entries are skipped gracefully.""" - stored_memories = [ - b"not valid json", - json.dumps({ - "content": "Valid Python memory", - "author": "user", - "timestamp": 12345, - }).encode(), + """Test that entries with missing content are skipped.""" + search_result = [ + 2, + { + b"adk:memory:abc123": { + b"content": b"", + b"author": b"user", + b"timestamp": b"12345", + }, + b"adk:memory:def456": { + b"content": b"Valid memory", + b"author": b"model", + b"timestamp": b"12346", + }, + }, ] - mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="Python", - ) + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.return_value = search_result - assert len(result.memories) == 1 - assert result.memories[0].content.parts[0].text == ( - "Valid Python memory" - ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == "Valid memory" @pytest.mark.asyncio - async def test_search_memory_multi_term_query( - self, memory_service, mock_valkey_client + async def test_search_memory_creates_index_if_needed( + self, mock_valkey_client ): - """Test search with multiple terms (any term matches).""" - stored_memories = [ - json.dumps({ - "content": "I love Python", - "author": "user", - "timestamp": 12345, - }).encode(), - json.dumps({ - "content": "Java is enterprise", - "author": "model", - "timestamp": 12346, - }).encode(), - json.dumps({ - "content": "Rust is fast", - "author": "user", - "timestamp": 12347, - }).encode(), - ] - mock_valkey_client.lrange = AsyncMock(return_value=stored_memories) + """Test that search creates index if not yet created.""" + service = ValkeyMemoryService(client=mock_valkey_client) - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="Python Java", - ) + with ( + patch("glide.ft.create", new_callable=AsyncMock) as mock_create, + patch("glide.ft.search", new_callable=AsyncMock) as mock_search, + ): + mock_create.return_value = "OK" + mock_search.return_value = [0, {}] - # Both "Python" and "Java" memories should match - assert len(result.memories) == 2 + await service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) - @pytest.mark.asyncio - async def test_search_memory_correct_key( - self, memory_service, mock_valkey_client - ): - """Test that the correct Valkey key is queried.""" - mock_valkey_client.lrange = AsyncMock(return_value=[]) + mock_create.assert_called_once() - await memory_service.search_memory( - app_name="my-app", - user_id="user-123", - query="test", - ) - mock_valkey_client.lrange.assert_called_once_with( - "adk:memory:my-app:user-123:entries", 0, -1 - ) +class TestValkeyMemoryServiceBuildQuery: + """Tests for _build_search_query.""" + + def test_basic_query(self, memory_service): + """Test basic query construction.""" + query = memory_service._build_search_query("myapp", "user1", "hello world") + assert "@app_name:{myapp}" in query + assert "@user_id:{user1}" in query + assert "hello world" in query + + def test_hyphenated_values(self, memory_service): + """Test escaping of hyphens in TAG values.""" + query = memory_service._build_search_query("my-app", "user-1", "test") + assert "my\\-app" in query + assert "user\\-1" in query + + def test_special_chars_in_query(self, memory_service): + """Test escaping of special search chars in query text.""" + query = memory_service._build_search_query("app", "user", "hello @world") + # @ should be escaped in the query text + assert "\\@world" in query class TestValkeyMemoryServiceClose: @@ -492,5 +512,4 @@ async def test_close_does_not_close_client( ): """Test that close does not close the underlying client.""" await memory_service.close() - # Client's close should NOT be called mock_valkey_client.close.assert_not_called() From c0762a8adfda85ef7b6d9de56886e486aff4bf18 Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Tue, 2 Jun 2026 11:52:49 -0700 Subject: [PATCH 3/7] feat: rewrite ValkeyMemoryService with vector similarity search (HNSW) Replaces full-text search with vector similarity search powered by the Valkey Search module, matching VertexAiRagMemoryService in functionality. Key changes: - Configurable embedding function (users bring their own embedder) - FT.CREATE with VECTOR field (HNSW, FLOAT32) for KNN search - FT.SEARCH with KNN pre-filtered by app_name/user_id TAG fields - Configurable vector_distance_threshold for filtering low-quality matches - Configurable distance_metric (COSINE, L2, IP) - Batch embedding generation for efficient ingestion - Implements add_events_to_memory for incremental ingestion - 30 unit tests, 12 integration tests (all passing) Ref: AEA-497 Signed-off-by: Daria Korenieva --- src/google/adk_community/memory/README.md | 74 +++- .../memory/valkey_memory_service.py | 286 +++++++++++++--- .../test_valkey_memory_service_integration.py | 278 ++++++++++----- .../memory/test_valkey_memory_service.py | 323 +++++++++++------- 4 files changed, 687 insertions(+), 274 deletions(-) diff --git a/src/google/adk_community/memory/README.md b/src/google/adk_community/memory/README.md index 7471eab4..03792a6c 100644 --- a/src/google/adk_community/memory/README.md +++ b/src/google/adk_community/memory/README.md @@ -8,21 +8,26 @@ Community-contributed memory service implementations for the ### ValkeyMemoryService A memory service backed by [Valkey](https://valkey.io/) using the -[Valkey Search module](https://valkey.io/topics/search/) for full-text -search. Uses the [valkey-glide](https://github.com/valkey-io/valkey-glide) -client library. +[Valkey Search module](https://valkey.io/topics/search/) for vector +similarity search. Uses the [valkey-glide](https://github.com/valkey-io/valkey-glide) +client library. This provides functionality analogous to +`VertexAiRagMemoryService` for users with Valkey infrastructure. **Features:** -- Full-text search powered by the Valkey Search module (FT.CREATE / FT.SEARCH) -- Memories stored as Valkey Hash keys with automatic indexing -- TAG-based filtering by `app_name` and `user_id` for scoped queries -- Configurable TTL for automatic memory expiration -- Case-insensitive search out of the box +- Vector similarity search (HNSW) powered by the Valkey Search module +- Configurable embedding function (bring your own: OpenAI, Gemini, sentence-transformers, etc.) +- KNN search with pre-filtering by `app_name` and `user_id` TAG fields +- Configurable `vector_distance_threshold` for filtering low-quality matches +- Configurable distance metric (COSINE, L2, IP) +- Optional TTL for automatic memory expiration +- Batch embedding generation for efficient ingestion +- Supports `add_session_to_memory` and `add_events_to_memory` **Requirements:** - Valkey server with the Search module loaded (e.g., [valkey-bundle](https://hub.docker.com/r/valkey/valkey-bundle) image) - `valkey-glide >= 2.4.0` +- An embedding function (async callable) **Installation:** @@ -36,26 +41,45 @@ pip install google-adk-community[valkey] from glide import GlideClient, GlideClientConfiguration, NodeAddress from google.adk_community.memory import ValkeyMemoryService, ValkeyMemoryServiceConfig -# Create a valkey-glide client +# 1. Create a valkey-glide client config = GlideClientConfiguration( addresses=[NodeAddress(host="localhost", port=6379)], client_name="my_adk_app", ) client = await GlideClient.create(config) -# Create the memory service +# 2. Define your embedding function (bring your own model) +# Example with Google Gemini: +from google import genai +genai_client = genai.Client() + +async def embed_texts(texts: list[str]) -> list[list[float]]: + response = await genai_client.models.embed_content_async( + model="text-embedding-004", + contents=texts, + ) + return [e.values for e in response.embeddings] + +# 3. Create the memory service memory_config = ValkeyMemoryServiceConfig( - search_top_k=10, # Max results per search - key_prefix="adk:memory", # Valkey key prefix - index_name="adk_memory_idx", # Search index name - ttl_seconds=None, # Optional TTL (None = no expiry) + similarity_top_k=10, # Max results per search (KNN) + vector_distance_threshold=0.6, # Filter distant results (optional) + embedding_dimensions=768, # Must match your embedding model + key_prefix="adk:memory", # Valkey key prefix + index_name="adk_memory_idx", # Search index name + distance_metric="COSINE", # COSINE, L2, or IP + ttl_seconds=None, # Optional TTL (None = no expiry) +) +memory_service = ValkeyMemoryService( + client=client, + embedding_function=embed_texts, + config=memory_config, ) -memory_service = ValkeyMemoryService(client=client, config=memory_config) # The index is created automatically on first use, or explicitly: await memory_service.create_index() -# Use with ADK runner +# 4. Use with ADK Runner from google.adk.runners import Runner runner = Runner( @@ -65,6 +89,16 @@ runner = Runner( ) ``` +**How it works:** + +1. When `add_session_to_memory` is called, text is extracted from session + events, embeddings are generated in batch using your embedding function, + and each event is stored as a Valkey Hash with the embedding vector. + +2. When `search_memory` is called, an embedding is generated for the query, + then `FT.SEARCH` performs a KNN search with pre-filtering by `app_name` + and `user_id` TAG fields. Results are returned ranked by vector similarity. + **Running Valkey with Search module:** ```bash @@ -75,12 +109,18 @@ podman run -d --name valkey -p 6379:6379 valkey/valkey-bundle:9.1 docker run -d --name valkey -p 6379:6379 valkey/valkey-bundle:9.1 ``` +**Note on Redis Session Service:** The existing `RedisSessionService` in this +repo is wire-protocol compatible with Valkey. You can point it at a Valkey +instance directly for session storage without needing a separate Valkey +session service. + --- ### OpenMemoryService A memory service backed by [OpenMemory](https://openmemory.cavira.app/). -Uses HTTP API calls for memory storage and retrieval. +Uses HTTP API calls for memory storage and retrieval with LLM-powered +memory extraction. **Installation:** diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py index 84b98f7f..5034a88e 100644 --- a/src/google/adk_community/memory/valkey_memory_service.py +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -14,13 +14,17 @@ """Valkey-backed memory service for ADK using valkey-glide client. -Uses the Valkey Search module (FT.CREATE / FT.SEARCH) for full-text -search over stored memories. +Uses the Valkey Search module with vector similarity search (HNSW) +for semantic memory retrieval, analogous to VertexAiRagMemoryService. """ from __future__ import annotations +from collections.abc import Awaitable +from collections.abc import Callable +from collections.abc import Sequence import logging +import struct import time from typing import Optional from typing import TYPE_CHECKING @@ -41,35 +45,51 @@ logger = logging.getLogger("google_adk." + __name__) +# Type alias for the embedding function. +# It takes a list of text strings and returns a list of float vectors. +EmbeddingFunction = Callable[[list[str]], Awaitable[list[list[float]]]] + class ValkeyMemoryServiceConfig(BaseModel): """Configuration for ValkeyMemoryService. Attributes: - search_top_k: Maximum number of memories to retrieve per search. + similarity_top_k: Maximum number of memories to retrieve per + search (KNN parameter). + vector_distance_threshold: Maximum distance threshold for + filtering results. Results with distance greater than this + are excluded. None means no threshold filtering. + embedding_dimensions: Dimensionality of the embedding vectors. key_prefix: Prefix for all Valkey keys to avoid collisions. index_name: Name of the Valkey Search index. - ttl_seconds: Optional TTL for memory entries in seconds. None means - no expiration. + distance_metric: Distance metric for vector similarity. + One of 'COSINE', 'L2', or 'IP' (inner product). + ttl_seconds: Optional TTL for memory entries in seconds. + None means no expiration. """ - search_top_k: int = Field(default=10, ge=1, le=100) + similarity_top_k: int = Field(default=10, ge=1, le=1000) + vector_distance_threshold: Optional[float] = Field(default=None, ge=0.0) + embedding_dimensions: int = Field(default=768, ge=1) key_prefix: str = Field(default="adk:memory") index_name: str = Field(default="adk_memory_idx") + distance_metric: str = Field(default="COSINE") ttl_seconds: Optional[int] = Field(default=None, ge=1) class ValkeyMemoryService(BaseMemoryService): - """Memory service implementation using Valkey with the Search module. + """Memory service using Valkey Search module with vector similarity. Uses valkey-glide client for communication with Valkey server and the - Valkey Search module (FT.CREATE / FT.SEARCH) for full-text search - over stored memories. + Valkey Search module for vector-based semantic search over stored + memories. This provides functionality analogous to + VertexAiRagMemoryService but backed by Valkey infrastructure. Memories are stored as Valkey Hash keys with fields: content, author, - timestamp, session_id, event_id, app_name, user_id, created_at. A - full-text search index is created over the content field, with TAG - fields for app_name and user_id to enable scoped queries. + timestamp, session_id, event_id, app_name, user_id, created_at, and + an embedding vector field. A vector search index (HNSW) is created + for approximate nearest neighbor retrieval, with TAG fields for + app_name and user_id to enable scoped queries. Example usage: @@ -80,7 +100,15 @@ class ValkeyMemoryService(BaseMemoryService): client_name="adk_memory_client", ) client = await GlideClient.create(config) - service = ValkeyMemoryService(client=client) + + async def my_embed_fn(texts: list[str]) -> list[list[float]]: + # Your embedding logic here (OpenAI, Gemini, etc.) + ... + + service = ValkeyMemoryService( + client=client, + embedding_function=my_embed_fn, + ) await service.create_index() """ @@ -88,6 +116,7 @@ class ValkeyMemoryService(BaseMemoryService): def __init__( self, client, + embedding_function: EmbeddingFunction, config: Optional[ValkeyMemoryServiceConfig] = None, ): """Initializes the Valkey memory service. @@ -96,6 +125,10 @@ def __init__( client: A connected valkey-glide GlideClient or GlideClusterClient instance. The caller is responsible for creating and managing the client lifecycle. + embedding_function: An async callable that takes a list of + text strings and returns a list of embedding vectors + (list of floats). Users provide their own embedding + model (e.g., OpenAI, Google Gemini, sentence-transformers). config: Optional ValkeyMemoryServiceConfig instance. If None, uses defaults. """ @@ -104,15 +137,22 @@ def __init__( "client is required. Provide a connected valkey-glide " "GlideClient or GlideClusterClient instance." ) + if embedding_function is None: + raise ValueError( + "embedding_function is required. Provide an async callable " + "that takes list[str] and returns list[list[float]]." + ) self._client = client + self._embedding_function = embedding_function self._config = config or ValkeyMemoryServiceConfig() self._index_created = False async def create_index(self): """Create the Valkey Search index if it does not already exist. - Creates a full-text search index with: - - content: TEXT field for full-text search + Creates a vector search index (HNSW) with: + - embedding: VECTOR field (HNSW, FLOAT32) for similarity search + - content: TEXT field for optional full-text filtering - app_name: TAG field for filtering by application - user_id: TAG field for filtering by user - author: TAG field for filtering by author @@ -122,18 +162,39 @@ async def create_index(self): will log a debug message and return without error. """ from glide import DataType + from glide import DistanceMetricType from glide import ft from glide import FtCreateOptions - from glide import NumericField from glide import TagField from glide import TextField + from glide import VectorAlgorithm + from glide import VectorField + from glide import VectorFieldAttributesHnsw + from glide import VectorType + + distance_map = { + "COSINE": DistanceMetricType.COSINE, + "L2": DistanceMetricType.L2, + "IP": DistanceMetricType.IP, + } + distance_metric = distance_map.get( + self._config.distance_metric.upper(), DistanceMetricType.COSINE + ) schema = [ + VectorField( + "embedding", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dimensions=self._config.embedding_dimensions, + distance_metric=distance_metric, + type=VectorType.FLOAT32, + ), + ), TextField("content"), TagField("app_name"), TagField("user_id"), TagField("author"), - NumericField("timestamp", sortable=True), ] options = FtCreateOptions( @@ -166,24 +227,56 @@ def _memory_hash_key(self) -> str: unique_id = uuid.uuid4().hex[:12] return f"{self._config.key_prefix}:{unique_id}" + @staticmethod + def _vector_to_bytes(vector: list[float]) -> bytes: + """Convert a list of floats to a binary blob for Valkey storage.""" + return struct.pack(f"<{len(vector)}f", *vector) + @override async def add_session_to_memory(self, session: Session): """Add a session's events to Valkey memory storage. - Each event with text content is stored as a separate Valkey Hash - key with the configured prefix, making it automatically indexed - by the search module. + Extracts text from session events, generates embeddings using the + configured embedding function, and stores each event as a Valkey + Hash with the embedding vector for later similarity search. """ if not self._index_created: await self.create_index() - memories_added = 0 - + # Collect texts and their corresponding events + texts = [] + valid_events = [] for event in session.events: content_text = extract_text_from_event(event) - if not content_text: - continue + if content_text: + texts.append(content_text) + valid_events.append(event) + + if not texts: + logger.debug("No text events to add from session %s", session.id) + return + + # Generate embeddings for all texts in one batch + try: + embeddings = await self._embedding_function(texts) + except Exception as e: + logger.error( + "Failed to generate embeddings for session %s: %s", + session.id, + e, + ) + return + if len(embeddings) != len(texts): + logger.error( + "Embedding function returned %d vectors for %d texts", + len(embeddings), + len(texts), + ) + return + + memories_added = 0 + for event, content_text, embedding in zip(valid_events, texts, embeddings): hash_key = self._memory_hash_key() field_values = { "content": content_text, @@ -194,6 +287,7 @@ async def add_session_to_memory(self, session: Session): "app_name": session.app_name, "user_id": session.user_id, "created_at": str(time.time()), + "embedding": self._vector_to_bytes(embedding), } try: @@ -208,46 +302,114 @@ async def add_session_to_memory(self, session: Session): logger.info("Added %d memories from session %s", memories_added, session.id) - def _build_search_query(self, app_name: str, user_id: str, query: str) -> str: - """Build an FT.SEARCH query string with filters. + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence, + session_id: str | None = None, + custom_metadata=None, + ) -> None: + """Adds an incremental list of events to memory. + + Generates embeddings and stores each event with text content. + This is useful for persisting only a subset of events (e.g., + the latest turn) without re-ingesting the full session. + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + events: The events to add to memory. + session_id: Optional session ID for partitioning. + custom_metadata: Optional metadata (unused currently). + """ + if not self._index_created: + await self.create_index() + + texts = [] + valid_events = [] + for event in events: + content_text = extract_text_from_event(event) + if content_text: + texts.append(content_text) + valid_events.append(event) + + if not texts: + return - Constructs a query that: - - Filters by app_name and user_id using TAG filters - - Searches content using full-text search + try: + embeddings = await self._embedding_function(texts) + except Exception as e: + logger.error("Failed to generate embeddings: %s", e) + return + + if len(embeddings) != len(texts): + logger.error( + "Embedding function returned %d vectors for %d texts", + len(embeddings), + len(texts), + ) + return + + memories_added = 0 + for event, content_text, embedding in zip(valid_events, texts, embeddings): + hash_key = self._memory_hash_key() + field_values = { + "content": content_text, + "author": event.author or "", + "timestamp": str(event.timestamp) if event.timestamp else "0", + "session_id": session_id or "", + "event_id": event.id or "", + "app_name": app_name, + "user_id": user_id, + "created_at": str(time.time()), + "embedding": self._vector_to_bytes(embedding), + } + + try: + await self._client.hset(hash_key, field_values) + memories_added += 1 + + if self._config.ttl_seconds: + await self._client.expire(hash_key, self._config.ttl_seconds) + except Exception as e: + logger.error("Failed to add memory for event %s: %s", event.id, e) + + logger.info("Added %d memories via add_events_to_memory", memories_added) + + def _build_knn_query(self, app_name: str, user_id: str, top_k: int) -> str: + """Build a KNN search query with TAG pre-filters. Args: app_name: Application name filter. user_id: User ID filter. - query: The user's search query text. + top_k: Number of nearest neighbors to retrieve. Returns: - A Valkey Search query string. + A Valkey Search KNN query string. """ - # Escape special characters in TAG values escaped_app = app_name.replace("-", "\\-") escaped_user = user_id.replace("-", "\\-") - # Build full-text query with TAG filters - # Use @field:{value} for TAG filtering and plain text for content - tag_filter = f"@app_name:{{{escaped_app}}} @user_id:{{{escaped_user}}}" - - # Escape special FT.SEARCH characters in the query text - search_chars = r'@!{}()|-=><~*:;$["\]^' - escaped_query = query - for ch in search_chars: - escaped_query = escaped_query.replace(ch, f"\\{ch}") - - return f"{tag_filter} {escaped_query}" + # KNN query with pre-filter: filter first, then KNN on results + return ( + f"(@app_name:{{{escaped_app}}} " + f"@user_id:{{{escaped_user}}})" + f"=>[KNN {top_k} @embedding $query_vec]" + ) @override async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: - """Search for memories matching the query using Valkey Search. + """Search for memories using vector similarity (KNN). - Uses FT.SEARCH with the Valkey Search module for full-text search. - Results are filtered by app_name and user_id, and the query is - matched against the content field. + Generates an embedding for the query text, then performs a KNN + search using FT.SEARCH with pre-filtering by app_name and + user_id. Results are ranked by vector distance (lower = more + similar for COSINE). Args: app_name: The application name to scope the search. @@ -255,19 +417,31 @@ async def search_memory( query: The search query string. Returns: - SearchMemoryResponse containing matching MemoryEntry objects. + SearchMemoryResponse containing matching MemoryEntry objects, + ordered by similarity. """ from glide import ft - from glide import FtSearchLimit from glide import FtSearchOptions if not self._index_created: await self.create_index() + # Generate embedding for the query try: - search_query = self._build_search_query(app_name, user_id, query) + query_embeddings = await self._embedding_function([query]) + query_embedding = query_embeddings[0] + except Exception as e: + logger.error("Failed to generate query embedding: %s", e) + return SearchMemoryResponse(memories=[]) + + query_vec_bytes = self._vector_to_bytes(query_embedding) + + try: + search_query = self._build_knn_query( + app_name, user_id, self._config.similarity_top_k + ) options = FtSearchOptions( - limit=FtSearchLimit(0, self._config.search_top_k), + params={"query_vec": query_vec_bytes}, ) result = await ft.search( @@ -280,7 +454,6 @@ async def search_memory( if not result or len(result) < 2: return SearchMemoryResponse(memories=[]) - # result is [count, {doc_id: {field: value, ...}, ...}] doc_count = result[0] if doc_count == 0: return SearchMemoryResponse(memories=[]) @@ -290,13 +463,20 @@ async def search_memory( for doc_id, fields in doc_map.items(): try: + # Check distance threshold if configured + if self._config.vector_distance_threshold is not None: + score_raw = self._decode(fields.get(b"__embedding_score", b"")) + if score_raw: + distance = float(score_raw) + if distance > self._config.vector_distance_threshold: + continue + content_text = self._decode(fields.get(b"content", b"")) if not content_text: continue author = self._decode(fields.get(b"author", b"")) or None timestamp_raw = self._decode(fields.get(b"timestamp", b"0")) - # Numeric fields may return "12345.0"; normalize to int string if timestamp_raw and timestamp_raw != "0": try: timestamp = str(int(float(timestamp_raw))) diff --git a/tests/integration/test_valkey_memory_service_integration.py b/tests/integration/test_valkey_memory_service_integration.py index b1e899c1..2ac5e71f 100644 --- a/tests/integration/test_valkey_memory_service_integration.py +++ b/tests/integration/test_valkey_memory_service_integration.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Integration tests for ValkeyMemoryService. +"""Integration tests for ValkeyMemoryService with vector similarity search. Requires a running Valkey instance with the Search module loaded. Set VALKEY_HOST and VALKEY_PORT environment variables if not using defaults (localhost:6379). -The valkey-bundle image (valkey/valkey-bundle) includes the Search -module. Run with: +The valkey-bundle image includes the Search module with vector support. +Run with: podman run -d --name valkey-test -p 6379:6379 valkey/valkey-bundle:9.1 pytest tests/integration/test_valkey_memory_service_integration.py -v @@ -28,6 +28,7 @@ from __future__ import annotations import asyncio +import math import os import uuid @@ -42,6 +43,33 @@ VALKEY_HOST = os.environ.get("VALKEY_HOST", "localhost") VALKEY_PORT = int(os.environ.get("VALKEY_PORT", "6379")) +# Simple deterministic embedding function for testing. +# Maps text to a 32-dim vector based on character frequencies. +EMBED_DIM = 32 + + +async def _test_embedding_function( + texts: list[str], +) -> list[list[float]]: + """Deterministic embedding function for testing. + + Generates a 32-dimensional vector based on character frequency + distribution. This gives semantically similar texts somewhat + similar vectors (texts with similar character distributions). + """ + embeddings = [] + for text in texts: + text_lower = text.lower() + vec = [0.0] * EMBED_DIM + for ch in text_lower: + idx = ord(ch) % EMBED_DIM + vec[idx] += 1.0 + # Normalize to unit vector + magnitude = math.sqrt(sum(x * x for x in vec)) or 1.0 + vec = [x / magnitude for x in vec] + embeddings.append(vec) + return embeddings + def _requires_valkey(): """Check if Valkey is available, skip if not.""" @@ -78,17 +106,20 @@ async def memory_service(valkey_client): config = ValkeyMemoryServiceConfig( key_prefix=test_prefix, index_name=index_name, - search_top_k=10, + similarity_top_k=10, + embedding_dimensions=EMBED_DIM, + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, ) - service = ValkeyMemoryService(client=valkey_client, config=config) await service.create_index() - - # Small delay for index to be ready await asyncio.sleep(0.1) yield service - # Cleanup: drop the index and delete test keys + # Cleanup try: await ft.dropindex(valkey_client, index_name) except Exception: @@ -152,38 +183,43 @@ def _make_session(app_name: str, user_id: str) -> Session: @pytest.mark.asyncio class TestValkeyMemoryServiceIntegration: - """Integration tests for ValkeyMemoryService with a real Valkey instance.""" + """Integration tests with vector similarity search.""" async def test_add_and_search_memories(self, memory_service): - """Test adding a session and searching for memories.""" + """Test adding a session and searching with vector similarity.""" session = _make_session("test-app", "user-1") - await memory_service.add_session_to_memory(session) - await asyncio.sleep(0.5) # Wait for indexing + await asyncio.sleep(0.5) result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="Python", + query="Python programming language", ) assert len(result.memories) >= 1 texts = [m.content.parts[0].text for m in result.memories] assert any("Python" in t for t in texts) - async def test_search_returns_empty_for_no_match(self, memory_service): - """Test that search returns empty when no memories match.""" + async def test_search_returns_results_ranked_by_similarity( + self, memory_service + ): + """Test that results are returned by vector similarity.""" session = _make_session("test-app", "user-1") await memory_service.add_session_to_memory(session) await asyncio.sleep(0.5) + # Query similar to "Python" content should return Python memories result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="JavaScript framework Angular", + query="I enjoy learning Python programming", ) - assert len(result.memories) == 0 + assert len(result.memories) >= 1 + # The most similar result should be about Python + top_text = result.memories[0].content.parts[0].text + assert "Python" in top_text or "python" in top_text.lower() async def test_user_isolation(self, memory_service): """Test that memories are isolated between users.""" @@ -210,21 +246,23 @@ async def test_user_isolation(self, memory_service): await memory_service.add_session_to_memory(session2) await asyncio.sleep(0.5) - # user-1 should not see user-2's memories + # user-1 should not see user-2's Java memory result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="Java", + query="Java programming", ) - assert len(result.memories) == 0 + texts = [m.content.parts[0].text for m in result.memories] + assert not any("Java" in t for t in texts) # user-2 should see their own Java memory result = await memory_service.search_memory( app_name="test-app", user_id="user-2", - query="Java", + query="Java programming", ) assert len(result.memories) == 1 + assert "Java" in result.memories[0].content.parts[0].text async def test_app_isolation(self, memory_service): """Test that memories are isolated between applications.""" @@ -271,17 +309,20 @@ async def test_app_isolation(self, memory_service): result = await memory_service.search_memory( app_name="app-one", user_id="user-1", - query="Kubernetes", + query="Kubernetes orchestration", ) assert len(result.memories) == 1 - # app-two should not see app-one's memories + # app-two should only see its own Docker memory, not Kubernetes result = await memory_service.search_memory( app_name="app-two", user_id="user-1", - query="Kubernetes", + query="Kubernetes orchestration", ) - assert len(result.memories) == 0 + # KNN returns nearest neighbor from the filtered set (app-two only). + # The result should NOT contain app-one's Kubernetes content. + for mem in result.memories: + assert "Kubernetes" not in mem.content.parts[0].text async def test_multiple_sessions_accumulate(self, memory_service): """Test that multiple sessions accumulate memories.""" @@ -311,7 +352,7 @@ async def test_multiple_sessions_accumulate(self, memory_service): result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="Python", + query="Python programming", ) # Should find memories from both sessions @@ -322,48 +363,12 @@ async def test_search_empty_store(self, memory_service): result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="anything", + query="anything at all", ) - assert len(result.memories) == 0 - async def test_search_case_insensitive(self, memory_service): - """Test that full-text search is case-insensitive.""" - session = Session( - app_name="test-app", - user_id="user-1", - id="session-case", - last_update_time=1000, - events=[ - Event( - id="event-case", - invocation_id="inv-case", - author="user", - timestamp=12345, - content=types.Content( - parts=[ - types.Part( - text="VALKEY is a high performance datastore." - ) - ] - ), - ), - ], - ) - - await memory_service.add_session_to_memory(session) - await asyncio.sleep(0.5) - - # Search with lowercase should find uppercase content - result = await memory_service.search_memory( - app_name="test-app", - user_id="user-1", - query="valkey", - ) - assert len(result.memories) == 1 - - async def test_search_top_k_limit(self, valkey_client): - """Test that search_top_k limits the number of results.""" + async def test_similarity_top_k_limit(self, valkey_client): + """Test that similarity_top_k limits the number of results.""" from glide import ft test_prefix = f"test:topk:{uuid.uuid4().hex[:8]}" @@ -371,13 +376,17 @@ async def test_search_top_k_limit(self, valkey_client): config = ValkeyMemoryServiceConfig( key_prefix=test_prefix, index_name=index_name, - search_top_k=3, + similarity_top_k=3, + embedding_dimensions=EMBED_DIM, + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, ) - service = ValkeyMemoryService(client=valkey_client, config=config) await service.create_index() await asyncio.sleep(0.1) - # Add more events than top_k events = [ Event( id=f"event-{i}", @@ -404,10 +413,9 @@ async def test_search_top_k_limit(self, valkey_client): result = await service.search_memory( app_name="test-app", user_id="user-1", - query="Python", + query="Python tips", ) - # Should return at most 3 (search_top_k) assert len(result.memories) <= 3 assert len(result.memories) >= 1 @@ -433,7 +441,6 @@ async def test_events_without_text_are_filtered(self, memory_service): id="session-filter", last_update_time=1000, events=[ - # Function call event - should be filtered Event( id="event-func", invocation_id="inv-func", @@ -447,14 +454,12 @@ async def test_events_without_text_are_filtered(self, memory_service): ] ), ), - # Empty event - should be filtered Event( id="event-empty", invocation_id="inv-empty", author="user", timestamp=12346, ), - # Valid text event - should be stored Event( id="event-text", invocation_id="inv-text", @@ -470,22 +475,13 @@ async def test_events_without_text_are_filtered(self, memory_service): await memory_service.add_session_to_memory(session) await asyncio.sleep(0.5) - # Only the text event should be searchable result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="valid content", + query="valid content text", ) assert len(result.memories) == 1 - # Function call content should not appear - result = await memory_service.search_memory( - app_name="test-app", - user_id="user-1", - query="search_tool", - ) - assert len(result.memories) == 0 - async def test_memory_entry_metadata(self, memory_service): """Test that returned MemoryEntry has correct metadata.""" session = Session( @@ -512,7 +508,7 @@ async def test_memory_entry_metadata(self, memory_service): result = await memory_service.search_memory( app_name="test-app", user_id="user-1", - query="Metadata verification", + query="Metadata verification test", ) assert len(result.memories) == 1 @@ -530,10 +526,14 @@ async def test_create_index_idempotent(self, valkey_client): config = ValkeyMemoryServiceConfig( key_prefix=test_prefix, index_name=index_name, + embedding_dimensions=EMBED_DIM, + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, ) - service = ValkeyMemoryService(client=valkey_client, config=config) - # First call should succeed await service.create_index() assert service._index_created is True @@ -546,3 +546,109 @@ async def test_create_index_idempotent(self, valkey_client): await ft.dropindex(valkey_client, index_name) except Exception: pass + + async def test_add_events_to_memory(self, memory_service): + """Test incremental event ingestion via add_events_to_memory.""" + events = [ + Event( + id="event-inc-1", + invocation_id="inv-inc-1", + author="user", + timestamp=50001, + content=types.Content( + parts=[types.Part(text="Incremental memory about Golang.")] + ), + ), + Event( + id="event-inc-2", + invocation_id="inv-inc-2", + author="model", + timestamp=50002, + content=types.Content( + parts=[types.Part(text="Go is great for concurrency.")] + ), + ), + ] + + await memory_service.add_events_to_memory( + app_name="test-app", + user_id="user-1", + events=events, + session_id="session-incremental", + ) + await asyncio.sleep(0.5) + + result = await memory_service.search_memory( + app_name="test-app", + user_id="user-1", + query="Golang concurrency", + ) + assert len(result.memories) >= 1 + texts = [m.content.parts[0].text for m in result.memories] + assert any("Go" in t for t in texts) + + async def test_vector_distance_threshold(self, valkey_client): + """Test that vector_distance_threshold filters distant results.""" + from glide import ft + + test_prefix = f"test:thresh:{uuid.uuid4().hex[:8]}" + index_name = f"test_thresh_idx_{uuid.uuid4().hex[:8]}" + config = ValkeyMemoryServiceConfig( + key_prefix=test_prefix, + index_name=index_name, + similarity_top_k=10, + embedding_dimensions=EMBED_DIM, + vector_distance_threshold=0.01, # Very strict threshold + ) + service = ValkeyMemoryService( + client=valkey_client, + embedding_function=_test_embedding_function, + config=config, + ) + await service.create_index() + await asyncio.sleep(0.1) + + session = Session( + app_name="test-app", + user_id="user-1", + id="session-thresh", + last_update_time=1000, + events=[ + Event( + id="event-thresh", + invocation_id="inv-thresh", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Completely unrelated topic XYZ.")] + ), + ), + ], + ) + + await service.add_session_to_memory(session) + await asyncio.sleep(0.5) + + # Search for something very different — should be filtered by threshold + result = await service.search_memory( + app_name="test-app", + user_id="user-1", + query="AAAAAAA BBBBBBB CCCCCCC", + ) + # With strict threshold, dissimilar results should be filtered + # (This depends on the embedding function producing distant vectors) + assert len(result.memories) <= 1 + + # Cleanup + try: + await ft.dropindex(valkey_client, index_name) + except Exception: + pass + try: + keys = await valkey_client.custom_command(["KEYS", f"{test_prefix}:*"]) + if keys: + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + await valkey_client.custom_command(["DEL", key_str]) + except Exception: + pass diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py index 602b3ea5..9cb63b80 100644 --- a/tests/unittests/memory/test_valkey_memory_service.py +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -27,6 +27,12 @@ MOCK_USER_ID = "test-user" MOCK_SESSION_ID = "session-1" + +async def _mock_embed_fn(texts: list[str]) -> list[list[float]]: + """Simple mock embedding function returning fixed-dim vectors.""" + return [[0.1] * 768 for _ in texts] + + MOCK_SESSION = Session( app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, @@ -97,8 +103,11 @@ def mock_valkey_client(): @pytest.fixture def memory_service(mock_valkey_client): """Create ValkeyMemoryService instance for testing.""" - service = ValkeyMemoryService(client=mock_valkey_client) - service._index_created = True # Skip index creation in unit tests + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + service._index_created = True return service @@ -106,12 +115,18 @@ def memory_service(mock_valkey_client): def memory_service_with_config(mock_valkey_client): """Create ValkeyMemoryService with custom config.""" config = ValkeyMemoryServiceConfig( - search_top_k=5, + similarity_top_k=5, key_prefix="custom:mem", index_name="custom_idx", ttl_seconds=3600, + embedding_dimensions=768, + vector_distance_threshold=0.5, + ) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + config=config, ) - service = ValkeyMemoryService(client=mock_valkey_client, config=config) service._index_created = True return service @@ -122,31 +137,40 @@ class TestValkeyMemoryServiceConfig: def test_default_config(self): """Test default configuration values.""" config = ValkeyMemoryServiceConfig() - assert config.search_top_k == 10 + assert config.similarity_top_k == 10 assert config.key_prefix == "adk:memory" assert config.index_name == "adk_memory_idx" assert config.ttl_seconds is None + assert config.embedding_dimensions == 768 + assert config.distance_metric == "COSINE" + assert config.vector_distance_threshold is None def test_custom_config(self): """Test custom configuration values.""" config = ValkeyMemoryServiceConfig( - search_top_k=20, + similarity_top_k=20, key_prefix="my:prefix", index_name="my_index", ttl_seconds=7200, + embedding_dimensions=1536, + distance_metric="L2", + vector_distance_threshold=0.8, ) - assert config.search_top_k == 20 + assert config.similarity_top_k == 20 assert config.key_prefix == "my:prefix" assert config.index_name == "my_index" assert config.ttl_seconds == 7200 + assert config.embedding_dimensions == 1536 + assert config.distance_metric == "L2" + assert config.vector_distance_threshold == 0.8 - def test_config_validation_search_top_k(self): - """Test search_top_k validation.""" + def test_config_validation_top_k(self): + """Test similarity_top_k validation.""" with pytest.raises(Exception): - ValkeyMemoryServiceConfig(search_top_k=0) + ValkeyMemoryServiceConfig(similarity_top_k=0) with pytest.raises(Exception): - ValkeyMemoryServiceConfig(search_top_k=101) + ValkeyMemoryServiceConfig(similarity_top_k=1001) class TestValkeyMemoryServiceInit: @@ -155,20 +179,32 @@ class TestValkeyMemoryServiceInit: def test_client_required(self): """Test that client is required.""" with pytest.raises(ValueError, match="client is required"): - ValkeyMemoryService(client=None) - - def test_init_with_client(self, mock_valkey_client): - """Test initialization with a valid client.""" - service = ValkeyMemoryService(client=mock_valkey_client) + ValkeyMemoryService(client=None, embedding_function=_mock_embed_fn) + + def test_embedding_function_required(self, mock_valkey_client): + """Test that embedding_function is required.""" + with pytest.raises(ValueError, match="embedding_function is required"): + ValkeyMemoryService(client=mock_valkey_client, embedding_function=None) + + def test_init_with_defaults(self, mock_valkey_client): + """Test initialization with default config.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) assert service._client is mock_valkey_client - assert service._config.search_top_k == 10 + assert service._config.similarity_top_k == 10 assert service._index_created is False def test_init_with_config(self, mock_valkey_client): """Test initialization with custom config.""" - config = ValkeyMemoryServiceConfig(search_top_k=5) - service = ValkeyMemoryService(client=mock_valkey_client, config=config) - assert service._config.search_top_k == 5 + config = ValkeyMemoryServiceConfig(similarity_top_k=5) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + config=config, + ) + assert service._config.similarity_top_k == 5 class TestValkeyMemoryServiceCreateIndex: @@ -177,7 +213,10 @@ class TestValkeyMemoryServiceCreateIndex: @pytest.mark.asyncio async def test_create_index_success(self, mock_valkey_client): """Test successful index creation.""" - service = ValkeyMemoryService(client=mock_valkey_client) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: mock_create.return_value = "OK" @@ -189,24 +228,28 @@ async def test_create_index_success(self, mock_valkey_client): @pytest.mark.asyncio async def test_create_index_already_exists(self, mock_valkey_client): """Test that existing index is handled gracefully.""" - service = ValkeyMemoryService(client=mock_valkey_client) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: mock_create.side_effect = Exception("Index already exists") await service.create_index() - assert service._index_created is True @pytest.mark.asyncio async def test_create_index_unexpected_error(self, mock_valkey_client): """Test that unexpected errors are raised.""" - service = ValkeyMemoryService(client=mock_valkey_client) + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: mock_create.side_effect = Exception("Connection refused") with pytest.raises(Exception, match="Connection refused"): await service.create_index() - assert service._index_created is False @@ -221,7 +264,7 @@ async def test_add_session_success(self, memory_service, mock_valkey_client): # Should make 2 hset calls (one per valid event with text) assert mock_valkey_client.hset.call_count == 2 - # Check first call stores correct fields + # Check first call stores correct fields including embedding first_call = mock_valkey_client.hset.call_args_list[0] key = first_call[0][0] assert key.startswith("adk:memory:") @@ -230,8 +273,8 @@ async def test_add_session_success(self, memory_service, mock_valkey_client): assert fields["author"] == "user" assert fields["app_name"] == MOCK_APP_NAME assert fields["user_id"] == MOCK_USER_ID - assert fields["session_id"] == MOCK_SESSION_ID - assert fields["event_id"] == "event-1" + assert "embedding" in fields + assert isinstance(fields["embedding"], bytes) @pytest.mark.asyncio async def test_add_session_filters_empty_events( @@ -248,7 +291,6 @@ async def test_add_session_with_ttl( """Test that TTL is set when configured.""" await memory_service_with_config.add_session_to_memory(MOCK_SESSION) - # Should set TTL on each hash key assert mock_valkey_client.expire.call_count == 2 expire_call = mock_valkey_client.expire.call_args_list[0] assert expire_call[0][1] == 3600 @@ -262,13 +304,29 @@ async def test_add_session_no_ttl_by_default( mock_valkey_client.expire.assert_not_called() @pytest.mark.asyncio - async def test_add_session_error_handling( + async def test_add_session_embedding_error(self, mock_valkey_client): + """Test handling of embedding function failure.""" + + async def _failing_embed(texts): + raise RuntimeError("Embedding service unavailable") + + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_failing_embed, + ) + service._index_created = True + + # Should not raise, just log error + await service.add_session_to_memory(MOCK_SESSION) + mock_valkey_client.hset.assert_not_called() + + @pytest.mark.asyncio + async def test_add_session_hset_error( self, memory_service, mock_valkey_client ): - """Test error handling during memory addition.""" + """Test error handling during hset.""" mock_valkey_client.hset.side_effect = Exception("Connection error") - # Should not raise exception, just log error await memory_service.add_session_to_memory(MOCK_SESSION) assert mock_valkey_client.hset.call_count == 2 @@ -286,17 +344,66 @@ async def test_add_session_custom_key_prefix( @pytest.mark.asyncio async def test_add_session_creates_index_if_needed(self, mock_valkey_client): """Test that create_index is called if not yet created.""" - service = ValkeyMemoryService(client=mock_valkey_client) - assert service._index_created is False + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: mock_create.return_value = "OK" await service.add_session_to_memory(MOCK_SESSION) - mock_create.assert_called_once() assert service._index_created is True +class TestValkeyMemoryServiceAddEvents: + """Tests for add_events_to_memory.""" + + @pytest.mark.asyncio + async def test_add_events_success(self, memory_service, mock_valkey_client): + """Test incremental event ingestion.""" + events = [ + Event( + id="ev-1", + invocation_id="inv-1", + author="user", + timestamp=100, + content=types.Content(parts=[types.Part(text="Hello world")]), + ), + ] + + await memory_service.add_events_to_memory( + app_name="myapp", + user_id="user1", + events=events, + session_id="sess-1", + ) + + assert mock_valkey_client.hset.call_count == 1 + fields = mock_valkey_client.hset.call_args_list[0][0][1] + assert fields["content"] == "Hello world" + assert fields["app_name"] == "myapp" + assert fields["user_id"] == "user1" + assert fields["session_id"] == "sess-1" + + @pytest.mark.asyncio + async def test_add_events_filters_empty( + self, memory_service, mock_valkey_client + ): + """Test that empty events are skipped.""" + events = [ + Event(id="ev-empty", invocation_id="inv-1", author="user"), + ] + + await memory_service.add_events_to_memory( + app_name="myapp", + user_id="user1", + events=events, + ) + + mock_valkey_client.hset.assert_not_called() + + class TestValkeyMemoryServiceSearch: """Tests for search_memory.""" @@ -304,7 +411,7 @@ class TestValkeyMemoryServiceSearch: async def test_search_memory_success( self, memory_service, mock_valkey_client ): - """Test successful memory search using FT.SEARCH.""" + """Test successful memory search using KNN.""" search_result = [ 2, { @@ -312,11 +419,13 @@ async def test_search_memory_success( b"content": b"I love Python programming", b"author": b"user", b"timestamp": b"12345", + b"__embedding_score": b"0.15", }, b"adk:memory:def456": { b"content": b"Python has great libraries", b"author": b"model", b"timestamp": b"12346", + b"__embedding_score": b"0.25", }, }, ] @@ -336,29 +445,12 @@ async def test_search_memory_success( == "I love Python programming" ) assert result.memories[0].author == "user" - assert result.memories[0].timestamp == "12345" - - @pytest.mark.asyncio - async def test_search_memory_no_results( - self, memory_service, mock_valkey_client - ): - """Test search with no matching memories.""" - with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: - mock_search.return_value = [0, {}] - - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="Rust language", - ) - - assert len(result.memories) == 0 @pytest.mark.asyncio async def test_search_memory_empty_result( self, memory_service, mock_valkey_client ): - """Test search when FT.SEARCH returns empty.""" + """Test search when no results.""" with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: mock_search.return_value = [0, {}] @@ -367,14 +459,13 @@ async def test_search_memory_empty_result( user_id=MOCK_USER_ID, query="anything", ) - assert len(result.memories) == 0 @pytest.mark.asyncio - async def test_search_memory_uses_correct_query( + async def test_search_memory_uses_knn_query( self, memory_service, mock_valkey_client ): - """Test that search builds correct FT.SEARCH query.""" + """Test that search builds a KNN query with TAG filters.""" with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: mock_search.return_value = [0, {}] @@ -388,59 +479,47 @@ async def test_search_memory_uses_correct_query( query_str = call_args[0][2] assert "my\\-app" in query_str assert "user\\-123" in query_str - assert "test query" in query_str + assert "KNN" in query_str + assert "@embedding" in query_str @pytest.mark.asyncio - async def test_search_memory_respects_top_k( - self, memory_service_with_config, mock_valkey_client + async def test_search_memory_passes_query_vec( + self, memory_service, mock_valkey_client ): - """Test that search uses search_top_k in LIMIT.""" + """Test that query embedding is passed as params.""" with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: mock_search.return_value = [0, {}] - await memory_service_with_config.search_memory( + await memory_service.search_memory( app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, - query="Python", + query="test", ) call_args = mock_search.call_args options = call_args[0][3] - assert options.limit is not None - - @pytest.mark.asyncio - async def test_search_memory_error_handling( - self, memory_service, mock_valkey_client - ): - """Test graceful error handling during search.""" - with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: - mock_search.side_effect = Exception("Connection error") - - result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query="test", - ) - - assert len(result.memories) == 0 + assert "query_vec" in options.params + assert isinstance(options.params["query_vec"], bytes) @pytest.mark.asyncio - async def test_search_memory_handles_missing_fields( - self, memory_service, mock_valkey_client + async def test_search_memory_distance_threshold( + self, memory_service_with_config, mock_valkey_client ): - """Test that entries with missing content are skipped.""" + """Test that distance threshold filters results.""" search_result = [ 2, { - b"adk:memory:abc123": { - b"content": b"", + b"adk:memory:close": { + b"content": b"Close match", b"author": b"user", - b"timestamp": b"12345", + b"timestamp": b"1", + b"__embedding_score": b"0.3", }, - b"adk:memory:def456": { - b"content": b"Valid memory", - b"author": b"model", - b"timestamp": b"12346", + b"adk:memory:far": { + b"content": b"Far match", + b"author": b"user", + b"timestamp": b"2", + b"__embedding_score": b"0.9", }, }, ] @@ -448,60 +527,68 @@ async def test_search_memory_handles_missing_fields( with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: mock_search.return_value = search_result - result = await memory_service.search_memory( + # Threshold is 0.5, so only "Close match" (0.3) should pass + result = await memory_service_with_config.search_memory( app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="test", ) assert len(result.memories) == 1 - assert result.memories[0].content.parts[0].text == "Valid memory" + assert result.memories[0].content.parts[0].text == "Close match" @pytest.mark.asyncio - async def test_search_memory_creates_index_if_needed( - self, mock_valkey_client - ): - """Test that search creates index if not yet created.""" - service = ValkeyMemoryService(client=mock_valkey_client) + async def test_search_memory_embedding_error(self, mock_valkey_client): + """Test graceful handling of embedding failure during search.""" - with ( - patch("glide.ft.create", new_callable=AsyncMock) as mock_create, - patch("glide.ft.search", new_callable=AsyncMock) as mock_search, - ): - mock_create.return_value = "OK" - mock_search.return_value = [0, {}] + async def _failing_embed(texts): + raise RuntimeError("Embedding service down") - await service.search_memory( + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_failing_embed, + ) + service._index_created = True + + result = await service.search_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + query="test", + ) + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_ft_search_error( + self, memory_service, mock_valkey_client + ): + """Test graceful handling of FT.SEARCH failure.""" + with patch("glide.ft.search", new_callable=AsyncMock) as mock_search: + mock_search.side_effect = Exception("Connection error") + + result = await memory_service.search_memory( app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="test", ) - - mock_create.assert_called_once() + assert len(result.memories) == 0 class TestValkeyMemoryServiceBuildQuery: - """Tests for _build_search_query.""" + """Tests for _build_knn_query.""" - def test_basic_query(self, memory_service): - """Test basic query construction.""" - query = memory_service._build_search_query("myapp", "user1", "hello world") + def test_basic_knn_query(self, memory_service): + """Test KNN query construction.""" + query = memory_service._build_knn_query("myapp", "user1", 10) assert "@app_name:{myapp}" in query assert "@user_id:{user1}" in query - assert "hello world" in query + assert "KNN 10 @embedding" in query def test_hyphenated_values(self, memory_service): """Test escaping of hyphens in TAG values.""" - query = memory_service._build_search_query("my-app", "user-1", "test") + query = memory_service._build_knn_query("my-app", "user-1", 5) assert "my\\-app" in query assert "user\\-1" in query - def test_special_chars_in_query(self, memory_service): - """Test escaping of special search chars in query text.""" - query = memory_service._build_search_query("app", "user", "hello @world") - # @ should be escaped in the query text - assert "\\@world" in query - class TestValkeyMemoryServiceClose: """Tests for close method.""" From b2754fc48e0fd080af96b85c89c2ffda90459deb Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Tue, 2 Jun 2026 12:04:58 -0700 Subject: [PATCH 4/7] fix: improve ValkeyMemoryService TAG escaping, type safety, and validation - Fix incomplete TAG value escaping in _build_knn_query: now escapes all Valkey Search metacharacters (dots, colons, @, spaces, etc.), not just hyphens. This prevents query injection and ensures correct scoping when app_name/user_id contain special characters. - Add _escape_tag_value() static helper with full Valkey Search spec coverage. - Add type annotation for client parameter (Union[GlideClient, GlideClusterClient]). - Add Pydantic field_validator for distance_metric to reject invalid values at config time instead of silently falling back to COSINE. - Fix TTL check to use 'is not None' instead of truthiness for Optional[int]. - Fix create_index docstring (removed incorrect 'timestamp: NUMERIC field'). - Add unit tests: special char escaping (dots, colons, @, spaces), distance_metric validation, _escape_tag_value coverage. Ref: AEA-497 Signed-off-by: Daria Korenieva --- .../memory/valkey_memory_service.py | 46 +++++++++-- .../memory/test_valkey_memory_service.py | 81 +++++++++++++++++++ 2 files changed, 121 insertions(+), 6 deletions(-) diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py index 5034a88e..50122117 100644 --- a/src/google/adk_community/memory/valkey_memory_service.py +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -28,6 +28,7 @@ import time from typing import Optional from typing import TYPE_CHECKING +from typing import Union import uuid from google.adk.memory.base_memory_service import BaseMemoryService @@ -41,6 +42,8 @@ from .utils import extract_text_from_event if TYPE_CHECKING: + from glide import GlideClient + from glide import GlideClusterClient from google.adk.sessions.session import Session logger = logging.getLogger("google_adk." + __name__) @@ -76,6 +79,20 @@ class ValkeyMemoryServiceConfig(BaseModel): distance_metric: str = Field(default="COSINE") ttl_seconds: Optional[int] = Field(default=None, ge=1) + @classmethod + def _validate_distance_metric(cls, v): + """Validate distance_metric is one of the allowed values.""" + allowed = {"COSINE", "L2", "IP"} + if v.upper() not in allowed: + raise ValueError(f"distance_metric must be one of {allowed}, got '{v}'") + return v.upper() + + from pydantic import field_validator + + _check_distance_metric = field_validator("distance_metric")( + _validate_distance_metric + ) + class ValkeyMemoryService(BaseMemoryService): """Memory service using Valkey Search module with vector similarity. @@ -115,7 +132,7 @@ async def my_embed_fn(texts: list[str]) -> list[list[float]]: def __init__( self, - client, + client: Union["GlideClient", "GlideClusterClient"], embedding_function: EmbeddingFunction, config: Optional[ValkeyMemoryServiceConfig] = None, ): @@ -156,7 +173,6 @@ async def create_index(self): - app_name: TAG field for filtering by application - user_id: TAG field for filtering by user - author: TAG field for filtering by author - - timestamp: NUMERIC field for sorting This method is idempotent — if the index already exists, it will log a debug message and return without error. @@ -295,7 +311,7 @@ async def add_session_to_memory(self, session: Session): memories_added += 1 logger.debug("Added memory for event %s at key %s", event.id, hash_key) - if self._config.ttl_seconds: + if self._config.ttl_seconds is not None: await self._client.expire(hash_key, self._config.ttl_seconds) except Exception as e: logger.error("Failed to add memory for event %s: %s", event.id, e) @@ -372,13 +388,31 @@ async def add_events_to_memory( await self._client.hset(hash_key, field_values) memories_added += 1 - if self._config.ttl_seconds: + if self._config.ttl_seconds is not None: await self._client.expire(hash_key, self._config.ttl_seconds) except Exception as e: logger.error("Failed to add memory for event %s: %s", event.id, e) logger.info("Added %d memories via add_events_to_memory", memories_added) + # Characters that must be escaped in Valkey Search TAG field values. + _TAG_SPECIAL_CHARS = set(r',.<>{}[]"' + r"':;!@#$%^&*()-+=~|/\\ ") + + @staticmethod + def _escape_tag_value(value: str) -> str: + """Escape special characters for Valkey Search TAG field queries. + + Per the Valkey Search query syntax, TAG values must have + metacharacters escaped with a backslash. + """ + escaped = [] + for ch in value: + if ch in ValkeyMemoryService._TAG_SPECIAL_CHARS: + escaped.append(f"\\{ch}") + else: + escaped.append(ch) + return "".join(escaped) + def _build_knn_query(self, app_name: str, user_id: str, top_k: int) -> str: """Build a KNN search query with TAG pre-filters. @@ -390,8 +424,8 @@ def _build_knn_query(self, app_name: str, user_id: str, top_k: int) -> str: Returns: A Valkey Search KNN query string. """ - escaped_app = app_name.replace("-", "\\-") - escaped_user = user_id.replace("-", "\\-") + escaped_app = self._escape_tag_value(app_name) + escaped_user = self._escape_tag_value(user_id) # KNN query with pre-filter: filter first, then KNN on results return ( diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py index 9cb63b80..ea3d5952 100644 --- a/tests/unittests/memory/test_valkey_memory_service.py +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -172,6 +172,25 @@ def test_config_validation_top_k(self): with pytest.raises(Exception): ValkeyMemoryServiceConfig(similarity_top_k=1001) + def test_config_validation_distance_metric(self): + """Test distance_metric validation rejects invalid values.""" + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(distance_metric="HAMMING") + + with pytest.raises(Exception): + ValkeyMemoryServiceConfig(distance_metric="invalid") + + def test_config_distance_metric_case_insensitive(self): + """Test that distance_metric is case-insensitive and normalized.""" + config = ValkeyMemoryServiceConfig(distance_metric="cosine") + assert config.distance_metric == "COSINE" + + config = ValkeyMemoryServiceConfig(distance_metric="l2") + assert config.distance_metric == "L2" + + config = ValkeyMemoryServiceConfig(distance_metric="ip") + assert config.distance_metric == "IP" + class TestValkeyMemoryServiceInit: """Tests for ValkeyMemoryService initialization.""" @@ -589,6 +608,30 @@ def test_hyphenated_values(self, memory_service): assert "my\\-app" in query assert "user\\-1" in query + def test_special_characters_dots(self, memory_service): + """Test escaping of dots in TAG values.""" + query = memory_service._build_knn_query("com.example.app", "user.1", 10) + assert "com\\.example\\.app" in query + assert "user\\.1" in query + + def test_special_characters_colons(self, memory_service): + """Test escaping of colons in TAG values.""" + query = memory_service._build_knn_query("app:v2", "user:123", 10) + assert "app\\:v2" in query + assert "user\\:123" in query + + def test_special_characters_at_sign(self, memory_service): + """Test escaping of @ sign in TAG values.""" + query = memory_service._build_knn_query("app@org", "user@domain", 10) + assert "app\\@org" in query + assert "user\\@domain" in query + + def test_special_characters_spaces(self, memory_service): + """Test escaping of spaces in TAG values.""" + query = memory_service._build_knn_query("my app", "user 1", 10) + assert "my\\ app" in query + assert "user\\ 1" in query + class TestValkeyMemoryServiceClose: """Tests for close method.""" @@ -600,3 +643,41 @@ async def test_close_does_not_close_client( """Test that close does not close the underlying client.""" await memory_service.close() mock_valkey_client.close.assert_not_called() + + +class TestValkeyMemoryServiceEscapeTagValue: + """Tests for _escape_tag_value static method.""" + + def test_no_special_characters(self): + """Test that plain values are unchanged.""" + assert ValkeyMemoryService._escape_tag_value("myapp") == "myapp" + + def test_hyphen_escaped(self): + """Test hyphen escaping.""" + assert ValkeyMemoryService._escape_tag_value("my-app") == "my\\-app" + + def test_dot_escaped(self): + """Test dot escaping.""" + assert ( + ValkeyMemoryService._escape_tag_value("com.example.app") + == "com\\.example\\.app" + ) + + def test_colon_escaped(self): + """Test colon escaping.""" + assert ValkeyMemoryService._escape_tag_value("app:v2") == "app\\:v2" + + def test_at_sign_escaped(self): + """Test @ sign escaping.""" + assert ( + ValkeyMemoryService._escape_tag_value("user@domain") == "user\\@domain" + ) + + def test_space_escaped(self): + """Test space escaping.""" + assert ValkeyMemoryService._escape_tag_value("my app") == "my\\ app" + + def test_multiple_special_chars(self): + """Test multiple special characters in one value.""" + result = ValkeyMemoryService._escape_tag_value("a-b.c:d@e") + assert result == "a\\-b\\.c\\:d\\@e" From 3d6fac0f8fc0e94b05ab2ef1e7f59d1197b035ba Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Wed, 3 Jun 2026 11:35:28 -0700 Subject: [PATCH 5/7] docs: emphasize client_name for observability in examples Add comments in docstring and README highlighting that client_name should be set on GlideClientConfiguration for visibility in CLIENT LIST, monitoring dashboards, and CloudWatch metrics. Signed-off-by: Daria Korenieva --- src/google/adk_community/memory/README.md | 4 +++- src/google/adk_community/memory/valkey_memory_service.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/google/adk_community/memory/README.md b/src/google/adk_community/memory/README.md index 03792a6c..25f4bd45 100644 --- a/src/google/adk_community/memory/README.md +++ b/src/google/adk_community/memory/README.md @@ -42,9 +42,11 @@ from glide import GlideClient, GlideClientConfiguration, NodeAddress from google.adk_community.memory import ValkeyMemoryService, ValkeyMemoryServiceConfig # 1. Create a valkey-glide client +# IMPORTANT: Set client_name for observability — it appears in CLIENT LIST, +# monitoring dashboards, and CloudWatch metrics. config = GlideClientConfiguration( addresses=[NodeAddress(host="localhost", port=6379)], - client_name="my_adk_app", + client_name="my_adk_memory_client", ) client = await GlideClient.create(config) diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py index 50122117..7a0bf7f5 100644 --- a/src/google/adk_community/memory/valkey_memory_service.py +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -112,6 +112,8 @@ class ValkeyMemoryService(BaseMemoryService): from glide import GlideClientConfiguration, NodeAddress, GlideClient + # IMPORTANT: Set client_name for observability in CLIENT LIST, + # monitoring dashboards, and CloudWatch metrics. config = GlideClientConfiguration( addresses=[NodeAddress(host="localhost", port=6379)], client_name="adk_memory_client", From a13d14144a207edc575714b365351abfe1fa007c Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Thu, 4 Jun 2026 08:46:33 -0700 Subject: [PATCH 6/7] Address review feedback: TAG escape, Batch pipelining, race condition, DRY refactor - Fix tenant isolation bypass: add '?' to _TAG_SPECIAL_CHARS escape set (single-char wildcard glob in Valkey Search TAG queries) - Use Batch pipelining for hset/expire calls (1 round trip vs 2N) - Add asyncio.Lock with double-check locking in _ensure_index() to prevent redundant FT.CREATE calls under concurrent access - Extract shared _ingest_events() method to DRY up add_session_to_memory and add_events_to_memory - Update unit tests for Batch-based approach; add tests for ? escaping and _ensure_index concurrency safety Signed-off-by: Daria Korenieva --- .../memory/valkey_memory_service.py | 141 ++++++------- .../memory/test_valkey_memory_service.py | 191 +++++++++++++----- 2 files changed, 208 insertions(+), 124 deletions(-) diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py index 7a0bf7f5..ab82f156 100644 --- a/src/google/adk_community/memory/valkey_memory_service.py +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -20,6 +20,7 @@ from __future__ import annotations +import asyncio from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Sequence @@ -165,6 +166,7 @@ def __init__( self._embedding_function = embedding_function self._config = config or ValkeyMemoryServiceConfig() self._index_created = False + self._index_lock = asyncio.Lock() async def create_index(self): """Create the Valkey Search index if it does not already exist. @@ -245,6 +247,14 @@ def _memory_hash_key(self) -> str: unique_id = uuid.uuid4().hex[:12] return f"{self._config.key_prefix}:{unique_id}" + async def _ensure_index(self): + """Ensure the search index exists, using double-check locking.""" + if self._index_created: + return + async with self._index_lock: + if not self._index_created: + await self.create_index() + @staticmethod def _vector_to_bytes(vector: list[float]) -> bytes: """Convert a list of floats to a binary blob for Valkey storage.""" @@ -258,67 +268,13 @@ async def add_session_to_memory(self, session: Session): configured embedding function, and stores each event as a Valkey Hash with the embedding vector for later similarity search. """ - if not self._index_created: - await self.create_index() - - # Collect texts and their corresponding events - texts = [] - valid_events = [] - for event in session.events: - content_text = extract_text_from_event(event) - if content_text: - texts.append(content_text) - valid_events.append(event) - - if not texts: - logger.debug("No text events to add from session %s", session.id) - return - - # Generate embeddings for all texts in one batch - try: - embeddings = await self._embedding_function(texts) - except Exception as e: - logger.error( - "Failed to generate embeddings for session %s: %s", - session.id, - e, - ) - return - - if len(embeddings) != len(texts): - logger.error( - "Embedding function returned %d vectors for %d texts", - len(embeddings), - len(texts), - ) - return - - memories_added = 0 - for event, content_text, embedding in zip(valid_events, texts, embeddings): - hash_key = self._memory_hash_key() - field_values = { - "content": content_text, - "author": event.author or "", - "timestamp": str(event.timestamp) if event.timestamp else "0", - "session_id": session.id, - "event_id": event.id or "", - "app_name": session.app_name, - "user_id": session.user_id, - "created_at": str(time.time()), - "embedding": self._vector_to_bytes(embedding), - } - - try: - await self._client.hset(hash_key, field_values) - memories_added += 1 - logger.debug("Added memory for event %s at key %s", event.id, hash_key) - - if self._config.ttl_seconds is not None: - await self._client.expire(hash_key, self._config.ttl_seconds) - except Exception as e: - logger.error("Failed to add memory for event %s: %s", event.id, e) - - logger.info("Added %d memories from session %s", memories_added, session.id) + await self._ensure_index() + await self._ingest_events( + events=session.events, + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + ) @override async def add_events_to_memory( @@ -343,9 +299,36 @@ async def add_events_to_memory( session_id: Optional session ID for partitioning. custom_metadata: Optional metadata (unused currently). """ - if not self._index_created: - await self.create_index() + await self._ensure_index() + await self._ingest_events( + events=events, + app_name=app_name, + user_id=user_id, + session_id=session_id or "", + ) + async def _ingest_events( + self, + events: Sequence, + app_name: str, + user_id: str, + session_id: str, + ) -> None: + """Shared ingestion logic for add_session_to_memory and add_events_to_memory. + + Extracts text from events, generates embeddings in batch, and + stores each event as a Valkey Hash using pipelined Batch commands + for efficiency. + + Args: + events: The events to ingest. + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + session_id: The session ID for partitioning. + """ + from glide import Batch + + # Collect texts and their corresponding events texts = [] valid_events = [] for event in events: @@ -355,8 +338,10 @@ async def add_events_to_memory( valid_events.append(event) if not texts: + logger.debug("No text events to ingest") return + # Generate embeddings for all texts in one batch try: embeddings = await self._embedding_function(texts) except Exception as e: @@ -371,14 +356,17 @@ async def add_events_to_memory( ) return - memories_added = 0 + # Build a Batch to pipeline all hset + expire calls + batch = Batch(is_atomic=False) + hash_keys = [] for event, content_text, embedding in zip(valid_events, texts, embeddings): hash_key = self._memory_hash_key() + hash_keys.append(hash_key) field_values = { "content": content_text, "author": event.author or "", "timestamp": str(event.timestamp) if event.timestamp else "0", - "session_id": session_id or "", + "session_id": session_id, "event_id": event.id or "", "app_name": app_name, "user_id": user_id, @@ -386,19 +374,19 @@ async def add_events_to_memory( "embedding": self._vector_to_bytes(embedding), } - try: - await self._client.hset(hash_key, field_values) - memories_added += 1 + batch.hset(hash_key, field_values) + if self._config.ttl_seconds is not None: + batch.expire(hash_key, self._config.ttl_seconds) - if self._config.ttl_seconds is not None: - await self._client.expire(hash_key, self._config.ttl_seconds) - except Exception as e: - logger.error("Failed to add memory for event %s: %s", event.id, e) - - logger.info("Added %d memories via add_events_to_memory", memories_added) + try: + await self._client.exec(batch, raise_on_error=True) + logger.info("Added %d memories via batch pipeline", len(hash_keys)) + except Exception as e: + logger.error("Failed to execute batch pipeline: %s", e) # Characters that must be escaped in Valkey Search TAG field values. - _TAG_SPECIAL_CHARS = set(r',.<>{}[]"' + r"':;!@#$%^&*()-+=~|/\\ ") + # Includes '?' which is a single-character wildcard glob in TAG queries. + _TAG_SPECIAL_CHARS = set(r',.<>{}[]"' + r"':;!@#$%^&*()-+=~|/\\ ?") @staticmethod def _escape_tag_value(value: str) -> str: @@ -459,8 +447,7 @@ async def search_memory( from glide import ft from glide import FtSearchOptions - if not self._index_created: - await self.create_index() + await self._ensure_index() # Generate embedding for the query try: diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py index ea3d5952..318e8f09 100644 --- a/tests/unittests/memory/test_valkey_memory_service.py +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -97,6 +97,7 @@ def mock_valkey_client(): client = AsyncMock() client.hset = AsyncMock(return_value=1) client.expire = AsyncMock(return_value=True) + client.exec = AsyncMock(return_value=[]) return client @@ -278,22 +279,19 @@ class TestValkeyMemoryServiceAddSession: @pytest.mark.asyncio async def test_add_session_success(self, memory_service, mock_valkey_client): """Test successful addition of session memories.""" - await memory_service.add_session_to_memory(MOCK_SESSION) - - # Should make 2 hset calls (one per valid event with text) - assert mock_valkey_client.hset.call_count == 2 - - # Check first call stores correct fields including embedding - first_call = mock_valkey_client.hset.call_args_list[0] - key = first_call[0][0] - assert key.startswith("adk:memory:") - fields = first_call[0][1] - assert fields["content"] == "Hello, I like Python." - assert fields["author"] == "user" - assert fields["app_name"] == MOCK_APP_NAME - assert fields["user_id"] == MOCK_USER_ID - assert "embedding" in fields - assert isinstance(fields["embedding"], bytes) + with patch("glide.Batch") as MockBatch: + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: None + + await memory_service.add_session_to_memory(MOCK_SESSION) + + # Should call exec once with the batch + mock_valkey_client.exec.assert_called_once_with( + mock_batch_instance, raise_on_error=True + ) + # Batch should be created as non-atomic + MockBatch.assert_called_once_with(is_atomic=False) @pytest.mark.asyncio async def test_add_session_filters_empty_events( @@ -301,26 +299,46 @@ async def test_add_session_filters_empty_events( ): """Test that events without text content are filtered out.""" await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) - assert mock_valkey_client.hset.call_count == 0 + mock_valkey_client.exec.assert_not_called() @pytest.mark.asyncio async def test_add_session_with_ttl( self, memory_service_with_config, mock_valkey_client ): """Test that TTL is set when configured.""" - await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + with patch("glide.Batch") as MockBatch: + hset_calls = [] + expire_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: hset_calls.append(args) + mock_batch_instance.expire = lambda *args, **kwargs: expire_calls.append( + args + ) + + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) - assert mock_valkey_client.expire.call_count == 2 - expire_call = mock_valkey_client.expire.call_args_list[0] - assert expire_call[0][1] == 3600 + # 2 valid events -> 2 hset + 2 expire calls on the batch + assert len(hset_calls) == 2 + assert len(expire_calls) == 2 + # TTL should be 3600 + assert expire_calls[0][1] == 3600 @pytest.mark.asyncio async def test_add_session_no_ttl_by_default( self, memory_service, mock_valkey_client ): """Test that no TTL is set when not configured.""" - await memory_service.add_session_to_memory(MOCK_SESSION) - mock_valkey_client.expire.assert_not_called() + with patch("glide.Batch") as MockBatch: + expire_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: expire_calls.append( + args + ) + + await memory_service.add_session_to_memory(MOCK_SESSION) + + assert len(expire_calls) == 0 @pytest.mark.asyncio async def test_add_session_embedding_error(self, mock_valkey_client): @@ -337,28 +355,39 @@ async def _failing_embed(texts): # Should not raise, just log error await service.add_session_to_memory(MOCK_SESSION) - mock_valkey_client.hset.assert_not_called() + mock_valkey_client.exec.assert_not_called() @pytest.mark.asyncio - async def test_add_session_hset_error( + async def test_add_session_batch_exec_error( self, memory_service, mock_valkey_client ): - """Test error handling during hset.""" - mock_valkey_client.hset.side_effect = Exception("Connection error") + """Test error handling during batch exec.""" + mock_valkey_client.exec.side_effect = Exception("Connection error") + + with patch("glide.Batch") as MockBatch: + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: None - await memory_service.add_session_to_memory(MOCK_SESSION) - assert mock_valkey_client.hset.call_count == 2 + # Should not raise, just log error + await memory_service.add_session_to_memory(MOCK_SESSION) + mock_valkey_client.exec.assert_called_once() @pytest.mark.asyncio async def test_add_session_custom_key_prefix( self, memory_service_with_config, mock_valkey_client ): """Test that custom key prefix is used.""" - await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + with patch("glide.Batch") as MockBatch: + hset_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: hset_calls.append(args) + mock_batch_instance.expire = lambda *args, **kwargs: None - first_call = mock_valkey_client.hset.call_args_list[0] - key = first_call[0][0] - assert key.startswith("custom:mem:") + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + key = hset_calls[0][0] + assert key.startswith("custom:mem:") @pytest.mark.asyncio async def test_add_session_creates_index_if_needed(self, mock_valkey_client): @@ -368,8 +397,15 @@ async def test_add_session_creates_index_if_needed(self, mock_valkey_client): embedding_function=_mock_embed_fn, ) - with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + with ( + patch("glide.ft.create", new_callable=AsyncMock) as mock_create, + patch("glide.Batch") as MockBatch, + ): mock_create.return_value = "OK" + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: None + mock_batch_instance.expire = lambda *args, **kwargs: None + await service.add_session_to_memory(MOCK_SESSION) mock_create.assert_called_once() assert service._index_created is True @@ -391,19 +427,26 @@ async def test_add_events_success(self, memory_service, mock_valkey_client): ), ] - await memory_service.add_events_to_memory( - app_name="myapp", - user_id="user1", - events=events, - session_id="sess-1", - ) + with patch("glide.Batch") as MockBatch: + hset_calls = [] + mock_batch_instance = MockBatch.return_value + mock_batch_instance.hset = lambda *args, **kwargs: hset_calls.append(args) + mock_batch_instance.expire = lambda *args, **kwargs: None + + await memory_service.add_events_to_memory( + app_name="myapp", + user_id="user1", + events=events, + session_id="sess-1", + ) - assert mock_valkey_client.hset.call_count == 1 - fields = mock_valkey_client.hset.call_args_list[0][0][1] - assert fields["content"] == "Hello world" - assert fields["app_name"] == "myapp" - assert fields["user_id"] == "user1" - assert fields["session_id"] == "sess-1" + assert len(hset_calls) == 1 + fields = hset_calls[0][1] + assert fields["content"] == "Hello world" + assert fields["app_name"] == "myapp" + assert fields["user_id"] == "user1" + assert fields["session_id"] == "sess-1" + mock_valkey_client.exec.assert_called_once() @pytest.mark.asyncio async def test_add_events_filters_empty( @@ -420,7 +463,7 @@ async def test_add_events_filters_empty( events=events, ) - mock_valkey_client.hset.assert_not_called() + mock_valkey_client.exec.assert_not_called() class TestValkeyMemoryServiceSearch: @@ -645,6 +688,55 @@ async def test_close_does_not_close_client( mock_valkey_client.close.assert_not_called() +class TestValkeyMemoryServiceEnsureIndex: + """Tests for _ensure_index with asyncio.Lock.""" + + @pytest.mark.asyncio + async def test_ensure_index_skips_if_already_created( + self, mock_valkey_client + ): + """Test that _ensure_index does not call create_index if already done.""" + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + service._index_created = True + + with patch("glide.ft.create", new_callable=AsyncMock) as mock_create: + await service._ensure_index() + mock_create.assert_not_called() + + @pytest.mark.asyncio + async def test_ensure_index_calls_create_index_once(self, mock_valkey_client): + """Test that concurrent calls to _ensure_index only create index once.""" + import asyncio + + service = ValkeyMemoryService( + client=mock_valkey_client, + embedding_function=_mock_embed_fn, + ) + + call_count = 0 + + async def mock_create(*args, **kwargs): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) # Simulate async work + return "OK" + + with patch("glide.ft.create", side_effect=mock_create): + # Launch multiple concurrent ensure_index calls + await asyncio.gather( + service._ensure_index(), + service._ensure_index(), + service._ensure_index(), + ) + + # Should only call create once due to lock + assert call_count == 1 + assert service._index_created is True + + class TestValkeyMemoryServiceEscapeTagValue: """Tests for _escape_tag_value static method.""" @@ -681,3 +773,8 @@ def test_multiple_special_chars(self): """Test multiple special characters in one value.""" result = ValkeyMemoryService._escape_tag_value("a-b.c:d@e") assert result == "a\\-b\\.c\\:d\\@e" + + def test_question_mark_escaped(self): + """Test question mark escaping (wildcard glob in TAG queries).""" + assert ValkeyMemoryService._escape_tag_value("app?") == "app\\?" + assert ValkeyMemoryService._escape_tag_value("a?b?c") == "a\\?b\\?c" From 9e576438db5ae17187dbfaecae3f7c8d1106bba1 Mon Sep 17 00:00:00 2001 From: Daria Korenieva Date: Fri, 5 Jun 2026 09:10:27 -0700 Subject: [PATCH 7/7] Address review feedback: fix pydantic validator, re-raise batch errors, remove dead code - Move field_validator import to top-level (was polluting class namespace) - Re-raise RuntimeError on batch exec failure instead of swallowing - Remove unreachable ternary fallback (guarded by earlier len check) - Update unit test to expect RuntimeError on batch exec failure Signed-off-by: Daria Korenieva --- .../adk_community/memory/valkey_memory_service.py | 11 ++++------- tests/unittests/memory/test_valkey_memory_service.py | 6 +++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/google/adk_community/memory/valkey_memory_service.py b/src/google/adk_community/memory/valkey_memory_service.py index ab82f156..6cc155a7 100644 --- a/src/google/adk_community/memory/valkey_memory_service.py +++ b/src/google/adk_community/memory/valkey_memory_service.py @@ -38,6 +38,7 @@ from google.genai import types from pydantic import BaseModel from pydantic import Field +from pydantic import field_validator from typing_extensions import override from .utils import extract_text_from_event @@ -80,6 +81,7 @@ class ValkeyMemoryServiceConfig(BaseModel): distance_metric: str = Field(default="COSINE") ttl_seconds: Optional[int] = Field(default=None, ge=1) + @field_validator("distance_metric") @classmethod def _validate_distance_metric(cls, v): """Validate distance_metric is one of the allowed values.""" @@ -88,12 +90,6 @@ def _validate_distance_metric(cls, v): raise ValueError(f"distance_metric must be one of {allowed}, got '{v}'") return v.upper() - from pydantic import field_validator - - _check_distance_metric = field_validator("distance_metric")( - _validate_distance_metric - ) - class ValkeyMemoryService(BaseMemoryService): """Memory service using Valkey Search module with vector similarity. @@ -383,6 +379,7 @@ async def _ingest_events( logger.info("Added %d memories via batch pipeline", len(hash_keys)) except Exception as e: logger.error("Failed to execute batch pipeline: %s", e) + raise RuntimeError(f"Memory ingestion failed: {e}") from e # Characters that must be escaped in Valkey Search TAG field values. # Includes '?' which is a single-character wildcard glob in TAG queries. @@ -482,7 +479,7 @@ async def search_memory( return SearchMemoryResponse(memories=[]) memories = [] - doc_map = result[1] if len(result) > 1 else {} + doc_map = result[1] for doc_id, fields in doc_map.items(): try: diff --git a/tests/unittests/memory/test_valkey_memory_service.py b/tests/unittests/memory/test_valkey_memory_service.py index 318e8f09..92bfb4e2 100644 --- a/tests/unittests/memory/test_valkey_memory_service.py +++ b/tests/unittests/memory/test_valkey_memory_service.py @@ -361,7 +361,7 @@ async def _failing_embed(texts): async def test_add_session_batch_exec_error( self, memory_service, mock_valkey_client ): - """Test error handling during batch exec.""" + """Test that batch exec failure raises RuntimeError.""" mock_valkey_client.exec.side_effect = Exception("Connection error") with patch("glide.Batch") as MockBatch: @@ -369,8 +369,8 @@ async def test_add_session_batch_exec_error( mock_batch_instance.hset = lambda *args, **kwargs: None mock_batch_instance.expire = lambda *args, **kwargs: None - # Should not raise, just log error - await memory_service.add_session_to_memory(MOCK_SESSION) + with pytest.raises(RuntimeError, match="Memory ingestion failed"): + await memory_service.add_session_to_memory(MOCK_SESSION) mock_valkey_client.exec.assert_called_once() @pytest.mark.asyncio