feat(memory)!: Uses the new MemoryEntry schema for all memory related components.

BREAKING CHANGE. This commit changes all memory related interface to using the newly introduced MemoryEntry class.

PiperOrigin-RevId: 758464887
This commit is contained in:
Wei Sun (Jack)
2025-05-13 19:06:16 -07:00
committed by Copybara-Service
parent 825f5d4f2e
commit 30947b48b8
10 changed files with 334 additions and 90 deletions

View File

@@ -0,0 +1,30 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..memory.memory_entry import MemoryEntry
def extract_text(memory: MemoryEntry, splitter: str = ' ') -> str:
"""Extracts the text from the memory entry."""
if not memory.content.parts:
return ''
return splitter.join(
[part.text for part in memory.content.parts if part.text]
)

View File

@@ -17,19 +17,25 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from openai import BaseModel
from pydantic import Field
from typing_extensions import override
from ..memory.memory_entry import MemoryEntry
from .function_tool import FunctionTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..memory.base_memory_service import MemoryResult
from ..models import LlmRequest
class LoadMemoryResponse(BaseModel):
memories: list[MemoryEntry] = Field(default_factory=list)
async def load_memory(
query: str, tool_context: ToolContext
) -> 'list[MemoryResult]':
) -> LoadMemoryResponse:
"""Loads the memory for the current user.
Args:
@@ -38,12 +44,15 @@ async def load_memory(
Returns:
A list of memory results.
"""
response = await tool_context.search_memory(query)
return response.memories
search_memory_response = await tool_context.search_memory(query)
return LoadMemoryResponse(memories=search_memory_response.memories)
class LoadMemoryTool(FunctionTool):
"""A tool that loads the memory for the current user."""
"""A tool that loads the memory for the current user.
NOTE: Currently this tool only uses text part from the memory.
"""
def __init__(self):
super().__init__(load_memory)

View File

@@ -14,11 +14,11 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from typing_extensions import override
from . import _memory_entry_utils
from .base_tool import BaseTool
from .tool_context import ToolContext
@@ -27,7 +27,10 @@ if TYPE_CHECKING:
class PreloadMemoryTool(BaseTool):
"""A tool that preloads the memory for the current user."""
"""A tool that preloads the memory for the current user.
NOTE: Currently this tool only uses text part from the memory.
"""
def __init__(self):
# Name and description are not used because this tool only
@@ -41,29 +44,35 @@ class PreloadMemoryTool(BaseTool):
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
parts = tool_context.user_content.parts
if not parts or not parts[0].text:
user_content = tool_context.user_content
if (
not user_content
or not user_content.parts
or not user_content.parts[0].text
):
return
query = parts[0].text
response = await tool_context.search_memory(query)
user_query: str = user_content.parts[0].text
response = await tool_context.search_memory(user_query)
if not response.memories:
return
memory_text = ''
memory_text_lines = []
for memory in response.memories:
time_str = datetime.fromtimestamp(memory.events[0].timestamp).isoformat()
memory_text += f'Time: {time_str}\n'
for event in memory.events:
# TODO: support multi-part content.
if (
event.content
and event.content.parts
and event.content.parts[0].text
):
memory_text += f'{event.author}: {event.content.parts[0].text}\n'
if time_str := (f'Time: {memory.timestamp}' if memory.timestamp else ''):
memory_text_lines.append(time_str)
if memory_text := _memory_entry_utils.extract_text(memory):
memory_text_lines.append(
f'{memory.author}: {memory_text}' if memory.author else memory_text
)
if not memory_text_lines:
return
full_memory_text = '\n'.join(memory_text_lines)
si = f"""The following content is from your previous conversations with the user.
They may be useful for answering the user's current query.
<PAST_CONVERSATIONS>
{memory_text}
{full_memory_text}
</PAST_CONVERSATIONS>
"""
llm_request.append_instructions([si])