From bcf1deb58234c49ea7a425d84e11c54aabb5e4f8 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 2 May 2025 14:18:38 -0700 Subject: [PATCH] ADK changes PiperOrigin-RevId: 754131080 --- src/google/adk/memory/base_memory_service.py | 4 ++-- src/google/adk/memory/in_memory_memory_service.py | 4 ++-- src/google/adk/memory/vertex_ai_rag_memory_service.py | 4 ++-- src/google/adk/runners.py | 4 ++-- src/google/adk/tools/load_memory_tool.py | 6 ++++-- src/google/adk/tools/preload_memory_tool.py | 2 +- src/google/adk/tools/tool_context.py | 4 ++-- 7 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/google/adk/memory/base_memory_service.py b/src/google/adk/memory/base_memory_service.py index 93c06b4..86ceba9 100644 --- a/src/google/adk/memory/base_memory_service.py +++ b/src/google/adk/memory/base_memory_service.py @@ -51,7 +51,7 @@ class BaseMemoryService(abc.ABC): """ @abc.abstractmethod - def add_session_to_memory(self, session: Session): + async def add_session_to_memory(self, session: Session): """Adds a session to the memory service. A session may be added multiple times during its lifetime. @@ -61,7 +61,7 @@ class BaseMemoryService(abc.ABC): """ @abc.abstractmethod - def search_memory( + async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: """Searches for sessions that match the query. diff --git a/src/google/adk/memory/in_memory_memory_service.py b/src/google/adk/memory/in_memory_memory_service.py index 8976344..1f15486 100644 --- a/src/google/adk/memory/in_memory_memory_service.py +++ b/src/google/adk/memory/in_memory_memory_service.py @@ -29,13 +29,13 @@ class InMemoryMemoryService(BaseMemoryService): self.session_events: dict[str, list[Event]] = {} """keys are app_name/user_id/session_id""" - def add_session_to_memory(self, session: Session): + async def add_session_to_memory(self, session: Session): key = f'{session.app_name}/{session.user_id}/{session.id}' self.session_events[key] = [ event for event in session.events if event.content ] - def search_memory( + async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: """Prototyping purpose only.""" diff --git a/src/google/adk/memory/vertex_ai_rag_memory_service.py b/src/google/adk/memory/vertex_ai_rag_memory_service.py index 3582260..c147ae8 100644 --- a/src/google/adk/memory/vertex_ai_rag_memory_service.py +++ b/src/google/adk/memory/vertex_ai_rag_memory_service.py @@ -54,7 +54,7 @@ class VertexAiRagMemoryService(BaseMemoryService): ) @override - def add_session_to_memory(self, session: Session): + async def add_session_to_memory(self, session: Session): with tempfile.NamedTemporaryFile( mode="w", delete=False, suffix=".txt" ) as temp_file: @@ -91,7 +91,7 @@ class VertexAiRagMemoryService(BaseMemoryService): os.remove(temp_file_path) @override - def search_memory( + async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: """Searches for sessions that match the query using rag.retrieval_query.""" diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 9041957..15b1ee2 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -297,14 +297,14 @@ class Runner: self.session_service.append_event(session=session, event=event) yield event - def close_session(self, session: Session): + async def close_session(self, session: Session): """Closes a session and adds it to the memory service (experimental feature). Args: session: The session to close. """ if self.memory_service: - self.memory_service.add_session_to_memory(session) + await self.memory_service.add_session_to_memory(session) self.session_service.close_session(session=session) def _find_agent_to_run( diff --git a/src/google/adk/tools/load_memory_tool.py b/src/google/adk/tools/load_memory_tool.py index 7b4de48..3fe530b 100644 --- a/src/google/adk/tools/load_memory_tool.py +++ b/src/google/adk/tools/load_memory_tool.py @@ -27,7 +27,9 @@ if TYPE_CHECKING: from ..models import LlmRequest -def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]': +async def load_memory( + query: str, tool_context: ToolContext +) -> 'list[MemoryResult]': """Loads the memory for the current user. Args: @@ -36,7 +38,7 @@ def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]': Returns: A list of memory results. """ - response = tool_context.search_memory(query) + response = await tool_context.search_memory(query) return response.memories diff --git a/src/google/adk/tools/preload_memory_tool.py b/src/google/adk/tools/preload_memory_tool.py index ebc682d..ddefc44 100644 --- a/src/google/adk/tools/preload_memory_tool.py +++ b/src/google/adk/tools/preload_memory_tool.py @@ -45,7 +45,7 @@ class PreloadMemoryTool(BaseTool): if not parts or not parts[0].text: return query = parts[0].text - response = tool_context.search_memory(query) + response = await tool_context.search_memory(query) if not response.memories: return memory_text = '' diff --git a/src/google/adk/tools/tool_context.py b/src/google/adk/tools/tool_context.py index e2d1262..ad8fce8 100644 --- a/src/google/adk/tools/tool_context.py +++ b/src/google/adk/tools/tool_context.py @@ -79,11 +79,11 @@ class ToolContext(CallbackContext): session_id=self._invocation_context.session.id, ) - def search_memory(self, query: str) -> 'SearchMemoryResponse': + async def search_memory(self, query: str) -> SearchMemoryResponse: """Searches the memory of the current user.""" if self._invocation_context.memory_service is None: raise ValueError('Memory service is not available.') - return self._invocation_context.memory_service.search_memory( + return await self._invocation_context.memory_service.search_memory( app_name=self._invocation_context.app_name, user_id=self._invocation_context.user_id, query=query,