ADK changes

PiperOrigin-RevId: 754131080
This commit is contained in:
Shangjie Chen 2025-05-02 14:18:38 -07:00 committed by Copybara-Service
parent 879064343c
commit bcf1deb582
7 changed files with 15 additions and 13 deletions

View File

@ -51,7 +51,7 @@ class BaseMemoryService(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def add_session_to_memory(self, session: Session): async def add_session_to_memory(self, session: Session):
"""Adds a session to the memory service. """Adds a session to the memory service.
A session may be added multiple times during its lifetime. A session may be added multiple times during its lifetime.
@ -61,7 +61,7 @@ class BaseMemoryService(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def search_memory( async def search_memory(
self, *, app_name: str, user_id: str, query: str self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse: ) -> SearchMemoryResponse:
"""Searches for sessions that match the query. """Searches for sessions that match the query.

View File

@ -29,13 +29,13 @@ class InMemoryMemoryService(BaseMemoryService):
self.session_events: dict[str, list[Event]] = {} self.session_events: dict[str, list[Event]] = {}
"""keys are app_name/user_id/session_id""" """keys are app_name/user_id/session_id"""
def add_session_to_memory(self, session: Session): async def add_session_to_memory(self, session: Session):
key = f'{session.app_name}/{session.user_id}/{session.id}' key = f'{session.app_name}/{session.user_id}/{session.id}'
self.session_events[key] = [ self.session_events[key] = [
event for event in session.events if event.content event for event in session.events if event.content
] ]
def search_memory( async def search_memory(
self, *, app_name: str, user_id: str, query: str self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse: ) -> SearchMemoryResponse:
"""Prototyping purpose only.""" """Prototyping purpose only."""

View File

@ -54,7 +54,7 @@ class VertexAiRagMemoryService(BaseMemoryService):
) )
@override @override
def add_session_to_memory(self, session: Session): async def add_session_to_memory(self, session: Session):
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".txt" mode="w", delete=False, suffix=".txt"
) as temp_file: ) as temp_file:
@ -91,7 +91,7 @@ class VertexAiRagMemoryService(BaseMemoryService):
os.remove(temp_file_path) os.remove(temp_file_path)
@override @override
def search_memory( async def search_memory(
self, *, app_name: str, user_id: str, query: str self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse: ) -> SearchMemoryResponse:
"""Searches for sessions that match the query using rag.retrieval_query.""" """Searches for sessions that match the query using rag.retrieval_query."""

View File

@ -297,14 +297,14 @@ class Runner:
self.session_service.append_event(session=session, event=event) self.session_service.append_event(session=session, event=event)
yield event yield event
def close_session(self, session: Session): async def close_session(self, session: Session):
"""Closes a session and adds it to the memory service (experimental feature). """Closes a session and adds it to the memory service (experimental feature).
Args: Args:
session: The session to close. session: The session to close.
""" """
if self.memory_service: if self.memory_service:
self.memory_service.add_session_to_memory(session) await self.memory_service.add_session_to_memory(session)
self.session_service.close_session(session=session) self.session_service.close_session(session=session)
def _find_agent_to_run( def _find_agent_to_run(

View File

@ -27,7 +27,9 @@ if TYPE_CHECKING:
from ..models import LlmRequest from ..models import LlmRequest
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]': async def load_memory(
query: str, tool_context: ToolContext
) -> 'list[MemoryResult]':
"""Loads the memory for the current user. """Loads the memory for the current user.
Args: Args:
@ -36,7 +38,7 @@ def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
Returns: Returns:
A list of memory results. A list of memory results.
""" """
response = tool_context.search_memory(query) response = await tool_context.search_memory(query)
return response.memories return response.memories

View File

@ -45,7 +45,7 @@ class PreloadMemoryTool(BaseTool):
if not parts or not parts[0].text: if not parts or not parts[0].text:
return return
query = parts[0].text query = parts[0].text
response = tool_context.search_memory(query) response = await tool_context.search_memory(query)
if not response.memories: if not response.memories:
return return
memory_text = '' memory_text = ''

View File

@ -79,11 +79,11 @@ class ToolContext(CallbackContext):
session_id=self._invocation_context.session.id, session_id=self._invocation_context.session.id,
) )
def search_memory(self, query: str) -> 'SearchMemoryResponse': async def search_memory(self, query: str) -> SearchMemoryResponse:
"""Searches the memory of the current user.""" """Searches the memory of the current user."""
if self._invocation_context.memory_service is None: if self._invocation_context.memory_service is None:
raise ValueError('Memory service is not available.') raise ValueError('Memory service is not available.')
return self._invocation_context.memory_service.search_memory( return await self._invocation_context.memory_service.search_memory(
app_name=self._invocation_context.app_name, app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id, user_id=self._invocation_context.user_id,
query=query, query=query,