feat(memory)!: Uses the new MemoryEntry schema for all memory-related components.

BREAKING CHANGE: This commit changes all memory-related interfaces to use the newly introduced MemoryEntry class.

PiperOrigin-RevId: 758464887
This commit is contained in:
Wei Sun (Jack) 2025-05-13 19:06:16 -07:00 committed by Copybara-Service
parent 825f5d4f2e
commit 30947b48b8
10 changed files with 334 additions and 90 deletions

View File

@ -12,20 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from datetime import datetime
from google.adk import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.tools.load_memory_tool import load_memory_tool
from google.adk.tools.preload_memory_tool import preload_memory_tool
from google.genai import types
def update_current_time(callback_context: CallbackContext):
    """Records the current time in the callback state under ``_time``.

    The value is the string returned by ``datetime.now().isoformat()``. The
    agent instruction in this file references it as ``{_time}``.

    Args:
        callback_context: Callback context whose ``state`` mapping is updated
            in place.
    """
    now_iso = datetime.now().isoformat()
    callback_context.state['_time'] = now_iso
root_agent = Agent(
model='gemini-2.0-flash-exp',
model='gemini-2.0-flash-001',
name='memory_agent',
description='agent that have access to memory tools.',
instruction="""
You are an agent that help user answer questions.
""",
tools=[load_memory_tool, preload_memory_tool],
before_agent_callback=update_current_time,
instruction="""\
You are an agent that help user answer questions.
Current time: {_time}
""",
tools=[
load_memory_tool,
preload_memory_tool,
],
)

View File

@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
from datetime import timedelta
from typing import cast
import warnings
import agent
from dotenv import load_dotenv
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.sessions import Session
from google.genai import types
load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()
async def main():
    """Runs the memory sample end to end.

    A first session tells the agent a few facts about the user and is then
    added to the runner's memory service (when one is configured). A second
    session asks questions whose answers depend on that stored memory.
    """
    app_name = 'my_app'
    user_id_1 = 'user1'
    runner = InMemoryRunner(
        agent=agent.root_agent,
        app_name=app_name,
    )

    async def run_prompt(session: Session, new_message: str) -> Session:
        """Sends one user message, prints the events, returns the session.

        The session is re-read from the session service after the run, so the
        returned object is the service's current copy.
        """
        content = types.Content(
            role='user', parts=[types.Part.from_text(text=new_message)]
        )
        print('** User says:', content.model_dump(exclude_none=True))
        event_stream = runner.run_async(
            user_id=user_id_1,
            session_id=session.id,
            new_message=content,
        )
        async for event in event_stream:
            if not (event.content and event.content.parts):
                continue
            # Only the first part of each event is inspected and printed.
            first_part = event.content.parts[0]
            if first_part.text:
                print(f'** {event.author}: {first_part.text}')
            elif first_part.function_call:
                call = first_part.function_call
                print(
                    f'** {event.author}: fc /'
                    f' {call.name} /'
                    f' {call.args}\n'
                )
            elif first_part.function_response:
                reply = first_part.function_response
                print(
                    f'** {event.author}: fr /'
                    f' {reply.name} /'
                    f' {reply.response}\n'
                )

        refreshed = runner.session_service.get_session(
            app_name=app_name, user_id=user_id_1, session_id=session.id
        )
        return cast(Session, refreshed)

    # Session 1: feed the agent facts that will later be stored as memory.
    session_1 = runner.session_service.create_session(
        app_name=app_name, user_id=user_id_1
    )
    print(f'----Session to create memory: {session_1.id} ----------------------')
    for opening_message in ('Hi', 'My name is Jack', 'I like badminton.'):
        session_1 = await run_prompt(session_1, opening_message)

    # Dates are computed right before each prompt is sent.
    burger_day = (datetime.now() - timedelta(days=1)).date()
    session_1 = await run_prompt(session_1, f'I ate a burger on {burger_day}.')
    banana_day = (datetime.now() - timedelta(days=2)).date()
    session_1 = await run_prompt(session_1, f'I ate a banana on {banana_day}.')

    print('Saving session to memory service...')
    if runner.memory_service:
        await runner.memory_service.add_session_to_memory(session_1)
    print('-------------------------------------------------------------------')

    # Session 2: a fresh session whose answers should come from memory.
    session_2 = runner.session_service.create_session(
        app_name=app_name, user_id=user_id_1
    )
    print(f'----Session to use memory: {session_2.id} ----------------------')
    follow_up_questions = (
        'Hi',
        'What do I like to do?',
        # ** memory_agent: You like badminton.
        'When did I say that?',
        # ** memory_agent: You said you liked badminton on ...
        'What did I eat yesterday?',
        # ** memory_agent: You ate a burger yesterday...
    )
    for question in follow_up_questions:
        session_2 = await run_prompt(session_2, question)
    print('-------------------------------------------------------------------')
if __name__ == '__main__':
asyncio.run(main())

View File

@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from datetime import datetime
def format_timestamp(timestamp: float) -> str:
    """Formats the timestamp of the memory entry as an ISO 8601 string.

    Args:
        timestamp: POSIX timestamp (seconds since the epoch).

    Returns:
        The ISO 8601 representation of the timestamp. No timezone is passed to
        ``datetime.fromtimestamp``, so the result is a naive local-time string
        without a UTC offset.
    """
    local_time = datetime.fromtimestamp(timestamp)
    return local_time.isoformat()

View File

@ -12,46 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import Field
from ..events.event import Event
from ..sessions.session import Session
from .memory_entry import MemoryEntry
class MemoryResult(BaseModel):
"""Represents a single memory retrieval result.
Attributes:
session_id: The session id associated with the memory.
events: A list of events in the session.
"""
session_id: str
events: list[Event]
if TYPE_CHECKING:
from ..sessions.session import Session
class SearchMemoryResponse(BaseModel):
"""Represents the response from a memory search.
Attributes:
memories: A list of memory results matching the search query.
memories: A list of memory entries that relate to the search query.
"""
memories: list[MemoryResult] = Field(default_factory=list)
memories: list[MemoryEntry] = Field(default_factory=list)
class BaseMemoryService(abc.ABC):
class BaseMemoryService(ABC):
"""Base class for memory services.
The service provides functionalities to ingest sessions into memory so that
the memory can be used for user queries.
"""
@abc.abstractmethod
async def add_session_to_memory(self, session: Session):
@abstractmethod
async def add_session_to_memory(
self,
session: Session,
):
"""Adds a session to the memory service.
A session may be added multiple times during its lifetime.
@ -60,9 +58,13 @@ class BaseMemoryService(abc.ABC):
session: The session to add.
"""
@abc.abstractmethod
@abstractmethod
async def search_memory(
self, *, app_name: str, user_id: str, query: str
self,
*,
app_name: str,
user_id: str,
query: str,
) -> SearchMemoryResponse:
"""Searches for sessions that match the query.

View File

@ -12,11 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..events.event import Event
from ..sessions.session import Session
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from typing_extensions import override
from . import _utils
from .base_memory_service import BaseMemoryService
from .base_memory_service import MemoryResult
from .base_memory_service import SearchMemoryResponse
from .memory_entry import MemoryEntry
if TYPE_CHECKING:
from ..events.event import Event
from ..sessions.session import Session
def _user_key(app_name: str, user_id: str):
    """Builds the per-user lookup key in the form ``<app_name>/<user_id>``."""
    key = f'{app_name}/{user_id}'
    return key
def _extract_words_lower(text: str) -> set[str]:
    """Extracts words from a string and converts them to lowercase.

    A word is a maximal run of ASCII letters (``[A-Za-z]+``). Digits,
    punctuation, whitespace and non-ASCII characters act as separators.

    Args:
        text: The text to split into words.

    Returns:
        The set of distinct lowercased words; empty when ``text`` contains no
        ASCII letters.
    """
    # Set comprehension avoids building an intermediate list just to dedupe.
    return {word.lower() for word in re.findall(r'[A-Za-z]+', text)}
class InMemoryMemoryService(BaseMemoryService):
@ -26,37 +46,49 @@ class InMemoryMemoryService(BaseMemoryService):
"""
def __init__(self):
self.session_events: dict[str, list[Event]] = {}
"""keys are app_name/user_id/session_id"""
self._session_events: dict[str, dict[str, list[Event]]] = {}
"""Keys are app_name/user_id, session_id. Values are session event lists."""
@override
async def add_session_to_memory(self, session: Session):
key = f'{session.app_name}/{session.user_id}/{session.id}'
self.session_events[key] = [
event for event in session.events if event.content
user_key = _user_key(session.app_name, session.user_id)
self._session_events[user_key] = self._session_events.get(
_user_key(session.app_name, session.user_id), {}
)
self._session_events[user_key][session.id] = [
event
for event in session.events
if event.content and event.content.parts
]
@override
async def search_memory(
self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse:
"""Prototyping purpose only."""
keywords = set(query.lower().split())
user_key = _user_key(app_name, user_id)
if user_key not in self._session_events:
return SearchMemoryResponse()
words_in_query = set(query.lower().split())
response = SearchMemoryResponse()
for key, events in self.session_events.items():
if not key.startswith(f'{app_name}/{user_id}/'):
continue
matched_events = []
for event in events:
for session_events in self._session_events[user_key].values():
for event in session_events:
if not event.content or not event.content.parts:
continue
parts = event.content.parts
text = '\n'.join([part.text for part in parts if part.text]).lower()
for keyword in keywords:
if keyword in text:
matched_events.append(event)
break
if matched_events:
session_id = key.split('/')[-1]
response.memories.append(
MemoryResult(session_id=session_id, events=matched_events)
words_in_event = _extract_words_lower(
' '.join([part.text for part in event.content.parts if part.text])
)
if not words_in_event:
continue
if any(query_word in words_in_event for query_word in words_in_query):
response.memories.append(
MemoryEntry(
content=event.content,
author=event.author,
timestamp=_utils.format_timestamp(event.timestamp),
)
)
return response

View File

@ -16,6 +16,7 @@ from collections import OrderedDict
import json
import os
import tempfile
from typing import Optional
from google.genai import types
from typing_extensions import override
@ -23,9 +24,10 @@ from vertexai.preview import rag
from ..events.event import Event
from ..sessions.session import Session
from . import _utils
from .base_memory_service import BaseMemoryService
from .base_memory_service import MemoryResult
from .base_memory_service import SearchMemoryResponse
from .memory_entry import MemoryEntry
class VertexAiRagMemoryService(BaseMemoryService):
@ -33,8 +35,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
def __init__(
self,
rag_corpus: str = None,
similarity_top_k: int = None,
rag_corpus: Optional[str] = None,
similarity_top_k: Optional[int] = None,
vector_distance_threshold: float = 10,
):
"""Initializes a VertexAiRagMemoryService.
@ -47,8 +49,10 @@ class VertexAiRagMemoryService(BaseMemoryService):
vector_distance_threshold: Only returns contexts with vector distance
smaller than the threshold.
"""
self.vertex_rag_store = types.VertexRagStore(
rag_resources=[rag.RagResource(rag_corpus=rag_corpus)],
self._vertex_rag_store = types.VertexRagStore(
rag_resources=[
types.VertexRagStoreRagResource(rag_corpus=rag_corpus),
],
similarity_top_k=similarity_top_k,
vector_distance_threshold=vector_distance_threshold,
)
@ -79,7 +83,11 @@ class VertexAiRagMemoryService(BaseMemoryService):
output_string = "\n".join(output_lines)
temp_file.write(output_string)
temp_file_path = temp_file.name
for rag_resource in self.vertex_rag_store.rag_resources:
if not self._vertex_rag_store.rag_resources:
raise ValueError("Rag resources must be set.")
for rag_resource in self._vertex_rag_store.rag_resources:
rag.upload_file(
corpus_name=rag_resource.rag_corpus,
path=temp_file_path,
@ -97,10 +105,10 @@ class VertexAiRagMemoryService(BaseMemoryService):
"""Searches for sessions that match the query using rag.retrieval_query."""
response = rag.retrieval_query(
text=query,
rag_resources=self.vertex_rag_store.rag_resources,
rag_corpora=self.vertex_rag_store.rag_corpora,
similarity_top_k=self.vertex_rag_store.similarity_top_k,
vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
rag_resources=self._vertex_rag_store.rag_resources,
rag_corpora=self._vertex_rag_store.rag_corpora,
similarity_top_k=self._vertex_rag_store.similarity_top_k,
vector_distance_threshold=self._vertex_rag_store.vector_distance_threshold,
)
memory_results = []
@ -144,9 +152,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
for session_id, event_lists in session_events_map.items():
for events in _merge_event_lists(event_lists):
sorted_events = sorted(events, key=lambda e: e.timestamp)
memory_results.append(
MemoryResult(session_id=session_id, events=sorted_events)
)
memory_results.extend([
MemoryEntry(
author=event.author,
content=event.content,
timestamp=_utils.format_timestamp(event.timestamp),
)
for event in sorted_events
if event.content
])
return SearchMemoryResponse(memories=memory_results)

View File

@ -18,7 +18,9 @@ import asyncio
import logging
import queue
import threading
from typing import AsyncGenerator, Generator, Optional
from typing import AsyncGenerator
from typing import Generator
from typing import Optional
import warnings
from deprecated import deprecated

View File

@ -0,0 +1,30 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..memory.memory_entry import MemoryEntry
def extract_text(memory: MemoryEntry, splitter: str = ' ') -> str:
    """Extracts the text from the memory entry.

    Args:
        memory: The memory entry whose ``content.parts`` are read.
        splitter: Separator placed between the text of consecutive parts.

    Returns:
        The non-empty ``text`` values of the parts joined by ``splitter``, or
        an empty string when the content has no parts.
    """
    parts = memory.content.parts
    if not parts:
        return ''
    # Parts with no text (None or empty) are skipped.
    text_chunks = (part.text for part in parts if part.text)
    return splitter.join(text_chunks)

View File

@ -17,19 +17,25 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from ..memory.memory_entry import MemoryEntry
from .function_tool import FunctionTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..memory.base_memory_service import MemoryResult
from ..models import LlmRequest
class LoadMemoryResponse(BaseModel):
memories: list[MemoryEntry] = Field(default_factory=list)
async def load_memory(
query: str, tool_context: ToolContext
) -> 'list[MemoryResult]':
) -> LoadMemoryResponse:
"""Loads the memory for the current user.
Args:
@ -38,12 +44,15 @@ async def load_memory(
Returns:
A LoadMemoryResponse whose `memories` field holds the memory entries found for the query.
"""
response = await tool_context.search_memory(query)
return response.memories
search_memory_response = await tool_context.search_memory(query)
return LoadMemoryResponse(memories=search_memory_response.memories)
class LoadMemoryTool(FunctionTool):
"""A tool that loads the memory for the current user."""
"""A tool that loads the memory for the current user.
NOTE: Currently this tool only uses text part from the memory.
"""
def __init__(self):
super().__init__(load_memory)

View File

@ -14,11 +14,11 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from typing_extensions import override
from . import _memory_entry_utils
from .base_tool import BaseTool
from .tool_context import ToolContext
@ -27,7 +27,10 @@ if TYPE_CHECKING:
class PreloadMemoryTool(BaseTool):
"""A tool that preloads the memory for the current user."""
"""A tool that preloads the memory for the current user.
NOTE: Currently this tool only uses text part from the memory.
"""
def __init__(self):
# Name and description are not used because this tool only
@ -41,29 +44,35 @@ class PreloadMemoryTool(BaseTool):
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
parts = tool_context.user_content.parts
if not parts or not parts[0].text:
user_content = tool_context.user_content
if (
not user_content
or not user_content.parts
or not user_content.parts[0].text
):
return
query = parts[0].text
response = await tool_context.search_memory(query)
user_query: str = user_content.parts[0].text
response = await tool_context.search_memory(user_query)
if not response.memories:
return
memory_text = ''
memory_text_lines = []
for memory in response.memories:
time_str = datetime.fromtimestamp(memory.events[0].timestamp).isoformat()
memory_text += f'Time: {time_str}\n'
for event in memory.events:
# TODO: support multi-part content.
if (
event.content
and event.content.parts
and event.content.parts[0].text
):
memory_text += f'{event.author}: {event.content.parts[0].text}\n'
if time_str := (f'Time: {memory.timestamp}' if memory.timestamp else ''):
memory_text_lines.append(time_str)
if memory_text := _memory_entry_utils.extract_text(memory):
memory_text_lines.append(
f'{memory.author}: {memory_text}' if memory.author else memory_text
)
if not memory_text_lines:
return
full_memory_text = '\n'.join(memory_text_lines)
si = f"""The following content is from your previous conversations with the user.
They may be useful for answering the user's current query.
<PAST_CONVERSATIONS>
{memory_text}
{full_memory_text}
</PAST_CONVERSATIONS>
"""
llm_request.append_instructions([si])