diff --git a/contributing/samples/memory/agent.py b/contributing/samples/memory/agent.py index 06c3202..3f41596 100755 --- a/contributing/samples/memory/agent.py +++ b/contributing/samples/memory/agent.py @@ -12,20 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random + +from datetime import datetime from google.adk import Agent +from google.adk.agents.callback_context import CallbackContext from google.adk.tools.load_memory_tool import load_memory_tool from google.adk.tools.preload_memory_tool import preload_memory_tool -from google.genai import types + + +def update_current_time(callback_context: CallbackContext): + callback_context.state['_time'] = datetime.now().isoformat() root_agent = Agent( - model='gemini-2.0-flash-exp', + model='gemini-2.0-flash-001', name='memory_agent', description='agent that have access to memory tools.', - instruction=""" - You are an agent that help user answer questions. - """, - tools=[load_memory_tool, preload_memory_tool], + before_agent_callback=update_current_time, + instruction="""\ +You are an agent that help user answer questions. + +Current time: {_time} +""", + tools=[ + load_memory_tool, + preload_memory_tool, + ], ) diff --git a/contributing/samples/memory/asyncio_run.py b/contributing/samples/memory/asyncio_run.py new file mode 100755 index 0000000..9cdea4e --- /dev/null +++ b/contributing/samples/memory/asyncio_run.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from datetime import datetime +from datetime import timedelta +from typing import cast +import warnings + +import agent +from dotenv import load_dotenv +from google.adk.cli.utils import logs +from google.adk.runners import InMemoryRunner +from google.adk.sessions import Session +from google.genai import types + +load_dotenv(override=True) +warnings.filterwarnings('ignore', category=UserWarning) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + runner = InMemoryRunner( + app_name=app_name, + agent=agent.root_agent, + ) + + async def run_prompt(session: Session, new_message: str) -> Session: + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if not event.content or not event.content.parts: + continue + if event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + elif event.content.parts[0].function_call: + print( + f'** {event.author}: fc /' + f' {event.content.parts[0].function_call.name} /' + f' {event.content.parts[0].function_call.args}\n' + ) + elif event.content.parts[0].function_response: + print( + f'** {event.author}: fr /' + f' {event.content.parts[0].function_response.name} /' + f' {event.content.parts[0].function_response.response}\n' + ) + + return cast( + Session, + runner.session_service.get_session( + app_name=app_name, user_id=user_id_1, session_id=session.id + ), + ) + + session_1 = runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + print(f'----Session to create memory: {session_1.id} ----------------------') + session_1 = await run_prompt(session_1, 'Hi') + session_1 = await run_prompt(session_1, 'My name is Jack') + session_1 = await run_prompt(session_1, 'I like badminton.') + session_1 = await run_prompt( + session_1, + f'I ate a burger on {(datetime.now() - timedelta(days=1)).date()}.', + ) + session_1 = await run_prompt( + session_1, + f'I ate a banana on {(datetime.now() - timedelta(days=2)).date()}.', + ) + print('Saving session to memory service...') + if runner.memory_service: + await runner.memory_service.add_session_to_memory(session_1) + print('-------------------------------------------------------------------') + + session_2 = runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + print(f'----Session to use memory: {session_2.id} ----------------------') + session_2 = await run_prompt(session_2, 'Hi') + session_2 = await run_prompt(session_2, 'What do I like to do?') + # ** memory_agent: You like badminton. + session_2 = await run_prompt(session_2, 'When did I say that?') + # ** memory_agent: You said you liked badminton on ... + session_2 = await run_prompt(session_2, 'What did I eat yesterday?') + # ** memory_agent: You ate a burger yesterday... + print('-------------------------------------------------------------------') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/memory/_utils.py b/src/google/adk/memory/_utils.py new file mode 100644 index 0000000..33c5640 --- /dev/null +++ b/src/google/adk/memory/_utils.py @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from datetime import datetime + + +def format_timestamp(timestamp: float) -> str: + """Formats the timestamp of the memory entry.""" + return datetime.fromtimestamp(timestamp).isoformat() diff --git a/src/google/adk/memory/base_memory_service.py b/src/google/adk/memory/base_memory_service.py index 86ceba9..65932de 100644 --- a/src/google/adk/memory/base_memory_service.py +++ b/src/google/adk/memory/base_memory_service.py @@ -12,46 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING from pydantic import BaseModel from pydantic import Field -from ..events.event import Event -from ..sessions.session import Session +from .memory_entry import MemoryEntry - -class MemoryResult(BaseModel): - """Represents a single memory retrieval result. - - Attributes: - session_id: The session id associated with the memory. - events: A list of events in the session. - """ - - session_id: str - events: list[Event] +if TYPE_CHECKING: + from ..sessions.session import Session class SearchMemoryResponse(BaseModel): """Represents the response from a memory search. Attributes: - memories: A list of memory results matching the search query. + memories: A list of memory entries that relate to the search query. """ - memories: list[MemoryResult] = Field(default_factory=list) + memories: list[MemoryEntry] = Field(default_factory=list) -class BaseMemoryService(abc.ABC): +class BaseMemoryService(ABC): """Base class for memory services. The service provides functionalities to ingest sessions into memory so that the memory can be used for user queries. """ - @abc.abstractmethod - async def add_session_to_memory(self, session: Session): + @abstractmethod + async def add_session_to_memory( + self, + session: Session, + ): """Adds a session to the memory service. A session may be added multiple times during its lifetime. @@ -60,9 +58,13 @@ class BaseMemoryService(abc.ABC): session: The session to add. """ - @abc.abstractmethod + @abstractmethod async def search_memory( - self, *, app_name: str, user_id: str, query: str + self, + *, + app_name: str, + user_id: str, + query: str, ) -> SearchMemoryResponse: """Searches for sessions that match the query. diff --git a/src/google/adk/memory/in_memory_memory_service.py b/src/google/adk/memory/in_memory_memory_service.py index 1f15486..a49aca5 100644 --- a/src/google/adk/memory/in_memory_memory_service.py +++ b/src/google/adk/memory/in_memory_memory_service.py @@ -12,11 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..events.event import Event -from ..sessions.session import Session + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from typing_extensions import override + +from . import _utils from .base_memory_service import BaseMemoryService -from .base_memory_service import MemoryResult from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..events.event import Event + from ..sessions.session import Session + + +def _user_key(app_name: str, user_id: str): + return f'{app_name}/{user_id}' + + +def _extract_words_lower(text: str) -> set[str]: + """Extracts words from a string and converts them to lowercase.""" + return set([word.lower() for word in re.findall(r'[A-Za-z]+', text)]) class InMemoryMemoryService(BaseMemoryService): @@ -26,37 +46,49 @@ class InMemoryMemoryService(BaseMemoryService): """ def __init__(self): - self.session_events: dict[str, list[Event]] = {} - """keys are app_name/user_id/session_id""" + self._session_events: dict[str, dict[str, list[Event]]] = {} + """Keys are app_name/user_id, session_id. Values are session event lists.""" + @override async def add_session_to_memory(self, session: Session): - key = f'{session.app_name}/{session.user_id}/{session.id}' - self.session_events[key] = [ - event for event in session.events if event.content + user_key = _user_key(session.app_name, session.user_id) + self._session_events[user_key] = self._session_events.get( + _user_key(session.app_name, session.user_id), {} + ) + self._session_events[user_key][session.id] = [ + event + for event in session.events + if event.content and event.content.parts ] + @override async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: - """Prototyping purpose only.""" - keywords = set(query.lower().split()) + user_key = _user_key(app_name, user_id) + if user_key not in self._session_events: + return SearchMemoryResponse() + + words_in_query = set(query.lower().split()) response = SearchMemoryResponse() - for key, events in self.session_events.items(): - if not key.startswith(f'{app_name}/{user_id}/'): - continue - matched_events = [] - for event in events: + + for session_events in self._session_events[user_key].values(): + for event in session_events: if not event.content or not event.content.parts: continue - parts = event.content.parts - text = '\n'.join([part.text for part in parts if part.text]).lower() - for keyword in keywords: - if keyword in text: - matched_events.append(event) - break - if matched_events: - session_id = key.split('/')[-1] - response.memories.append( - MemoryResult(session_id=session_id, events=matched_events) + words_in_event = _extract_words_lower( + ' '.join([part.text for part in event.content.parts if part.text]) ) + if not words_in_event: + continue + + if any(query_word in words_in_event for query_word in words_in_query): + response.memories.append( + MemoryEntry( + content=event.content, + author=event.author, + timestamp=_utils.format_timestamp(event.timestamp), + ) + ) + return response diff --git a/src/google/adk/memory/vertex_ai_rag_memory_service.py b/src/google/adk/memory/vertex_ai_rag_memory_service.py index c147ae8..2322071 100644 --- a/src/google/adk/memory/vertex_ai_rag_memory_service.py +++ b/src/google/adk/memory/vertex_ai_rag_memory_service.py @@ -16,6 +16,7 @@ from collections import OrderedDict import json import os import tempfile +from typing import Optional from google.genai import types from typing_extensions import override @@ -23,9 +24,10 @@ from vertexai.preview import rag from ..events.event import Event from ..sessions.session import Session +from . import _utils from .base_memory_service import BaseMemoryService -from .base_memory_service import MemoryResult from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry class VertexAiRagMemoryService(BaseMemoryService): @@ -33,8 +35,8 @@ class VertexAiRagMemoryService(BaseMemoryService): def __init__( self, - rag_corpus: str = None, - similarity_top_k: int = None, + rag_corpus: Optional[str] = None, + similarity_top_k: Optional[int] = None, vector_distance_threshold: float = 10, ): """Initializes a VertexAiRagMemoryService. @@ -47,8 +49,10 @@ class VertexAiRagMemoryService(BaseMemoryService): vector_distance_threshold: Only returns contexts with vector distance smaller than the threshold.. """ - self.vertex_rag_store = types.VertexRagStore( - rag_resources=[rag.RagResource(rag_corpus=rag_corpus)], + self._vertex_rag_store = types.VertexRagStore( + rag_resources=[ + types.VertexRagStoreRagResource(rag_corpus=rag_corpus), + ], similarity_top_k=similarity_top_k, vector_distance_threshold=vector_distance_threshold, ) @@ -79,7 +83,11 @@ class VertexAiRagMemoryService(BaseMemoryService): output_string = "\n".join(output_lines) temp_file.write(output_string) temp_file_path = temp_file.name - for rag_resource in self.vertex_rag_store.rag_resources: + + if not self._vertex_rag_store.rag_resources: + raise ValueError("Rag resources must be set.") + + for rag_resource in self._vertex_rag_store.rag_resources: rag.upload_file( corpus_name=rag_resource.rag_corpus, path=temp_file_path, @@ -97,10 +105,10 @@ class VertexAiRagMemoryService(BaseMemoryService): """Searches for sessions that match the query using rag.retrieval_query.""" response = rag.retrieval_query( text=query, - rag_resources=self.vertex_rag_store.rag_resources, - rag_corpora=self.vertex_rag_store.rag_corpora, - similarity_top_k=self.vertex_rag_store.similarity_top_k, - vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold, + rag_resources=self._vertex_rag_store.rag_resources, + rag_corpora=self._vertex_rag_store.rag_corpora, + similarity_top_k=self._vertex_rag_store.similarity_top_k, + vector_distance_threshold=self._vertex_rag_store.vector_distance_threshold, ) memory_results = [] @@ -144,9 +152,16 @@ class VertexAiRagMemoryService(BaseMemoryService): for session_id, event_lists in session_events_map.items(): for events in _merge_event_lists(event_lists): sorted_events = sorted(events, key=lambda e: e.timestamp) - memory_results.append( - MemoryResult(session_id=session_id, events=sorted_events) - ) + + memory_results.extend([ + MemoryEntry( + author=event.author, + content=event.content, + timestamp=_utils.format_timestamp(event.timestamp), + ) + for event in sorted_events + if event.content + ]) return SearchMemoryResponse(memories=memory_results) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 4de3acc..3c384a8 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -18,7 +18,9 @@ import asyncio import logging import queue import threading -from typing import AsyncGenerator, Generator, Optional +from typing import AsyncGenerator +from typing import Generator +from typing import Optional import warnings from deprecated import deprecated diff --git a/src/google/adk/tools/_memory_entry_utils.py b/src/google/adk/tools/_memory_entry_utils.py new file mode 100644 index 0000000..80caf6d --- /dev/null +++ b/src/google/adk/tools/_memory_entry_utils.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..memory.memory_entry import MemoryEntry + + +def extract_text(memory: MemoryEntry, splitter: str = ' ') -> str: + """Extracts the text from the memory entry.""" + if not memory.content.parts: + return '' + return splitter.join( + [part.text for part in memory.content.parts if part.text] + ) diff --git a/src/google/adk/tools/load_memory_tool.py b/src/google/adk/tools/load_memory_tool.py index 3fe530b..e8702b5 100644 --- a/src/google/adk/tools/load_memory_tool.py +++ b/src/google/adk/tools/load_memory_tool.py @@ -17,19 +17,25 @@ from __future__ import annotations from typing import TYPE_CHECKING from google.genai import types +from openai import BaseModel +from pydantic import Field from typing_extensions import override +from ..memory.memory_entry import MemoryEntry from .function_tool import FunctionTool from .tool_context import ToolContext if TYPE_CHECKING: - from ..memory.base_memory_service import MemoryResult from ..models import LlmRequest +class LoadMemoryResponse(BaseModel): + memories: list[MemoryEntry] = Field(default_factory=list) + + async def load_memory( query: str, tool_context: ToolContext -) -> 'list[MemoryResult]': +) -> LoadMemoryResponse: """Loads the memory for the current user. Args: @@ -38,12 +44,15 @@ async def load_memory( Returns: A list of memory results. """ - response = await tool_context.search_memory(query) - return response.memories + search_memory_response = await tool_context.search_memory(query) + return LoadMemoryResponse(memories=search_memory_response.memories) class LoadMemoryTool(FunctionTool): - """A tool that loads the memory for the current user.""" + """A tool that loads the memory for the current user. + + NOTE: Currently this tool only uses text part from the memory. + """ def __init__(self): super().__init__(load_memory) diff --git a/src/google/adk/tools/preload_memory_tool.py b/src/google/adk/tools/preload_memory_tool.py index ddefc44..8aa24a2 100644 --- a/src/google/adk/tools/preload_memory_tool.py +++ b/src/google/adk/tools/preload_memory_tool.py @@ -14,11 +14,11 @@ from __future__ import annotations -from datetime import datetime from typing import TYPE_CHECKING from typing_extensions import override +from . import _memory_entry_utils from .base_tool import BaseTool from .tool_context import ToolContext @@ -27,7 +27,10 @@ if TYPE_CHECKING: class PreloadMemoryTool(BaseTool): - """A tool that preloads the memory for the current user.""" + """A tool that preloads the memory for the current user. + + NOTE: Currently this tool only uses text part from the memory. + """ def __init__(self): # Name and description are not used because this tool only @@ -41,29 +44,35 @@ class PreloadMemoryTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - parts = tool_context.user_content.parts - if not parts or not parts[0].text: + user_content = tool_context.user_content + if ( + not user_content + or not user_content.parts + or not user_content.parts[0].text + ): return - query = parts[0].text - response = await tool_context.search_memory(query) + + user_query: str = user_content.parts[0].text + response = await tool_context.search_memory(user_query) if not response.memories: return - memory_text = '' + + memory_text_lines = [] for memory in response.memories: - time_str = datetime.fromtimestamp(memory.events[0].timestamp).isoformat() - memory_text += f'Time: {time_str}\n' - for event in memory.events: - # TODO: support multi-part content. - if ( - event.content - and event.content.parts - and event.content.parts[0].text - ): - memory_text += f'{event.author}: {event.content.parts[0].text}\n' + if time_str := (f'Time: {memory.timestamp}' if memory.timestamp else ''): + memory_text_lines.append(time_str) + if memory_text := _memory_entry_utils.extract_text(memory): + memory_text_lines.append( + f'{memory.author}: {memory_text}' if memory.author else memory_text + ) + if not memory_text_lines: + return + + full_memory_text = '\n'.join(memory_text_lines) si = f"""The following content is from your previous conversations with the user. They may be useful for answering the user's current query. -{memory_text} +{full_memory_text} """ llm_request.append_instructions([si])