No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member
2025-04-17 19:50:22 +00:00
committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from datetime import datetime
import json
@@ -20,17 +21,17 @@ from typing import Any
from typing import Optional
import uuid
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
@@ -54,6 +55,7 @@ from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
logger = logging.getLogger(__name__)
@@ -103,7 +105,7 @@ class StorageSession(Base):
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@@ -134,8 +136,20 @@ class StorageEvent(Base):
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict] = mapped_column(DynamicJSON)
actions: Mapped[dict] = mapped_column(PickleType)
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
@@ -150,13 +164,28 @@ class StorageEvent(Base):
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
@long_running_tool_ids.setter
def long_running_tool_ids(self, value: set[str]):
if value is None:
self.long_running_tool_ids_json = None
else:
self.long_running_tool_ids_json = json.dumps(list(value))
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
@@ -170,7 +199,7 @@ class StorageUserState(Base):
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
@@ -295,7 +324,6 @@ class DatabaseSessionService(BaseSessionService):
last_update_time=storage_session.update_time.timestamp(),
)
return session
return None
@override
def get_session(
@@ -309,7 +337,6 @@ class DatabaseSessionService(BaseSessionService):
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
session: Session = None
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
StorageSession, (app_name, user_id, session_id)
@@ -356,13 +383,19 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=e.content,
content=_decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
error_message=e.error_message,
interrupted=e.interrupted,
)
for e in storage_events
]
return session
@override
@@ -387,7 +420,6 @@ class DatabaseSessionService(BaseSessionService):
)
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
raise ValueError("Failed to retrieve sessions.")
@override
def delete_session(
@@ -406,7 +438,7 @@ class DatabaseSessionService(BaseSessionService):
def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}")
if event.partial and not event.content:
if event.partial:
return event
# 1. Check if timestamp is stale
@@ -455,19 +487,34 @@ class DatabaseSessionService(BaseSessionService):
storage_user_state.state = user_state
storage_session.state = session_state
encoded_content = event.content.model_dump(exclude_none=True)
storage_event = StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
content=encoded_content,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
sessionFactory.add(storage_event)
@@ -489,8 +536,7 @@ class DatabaseSessionService(BaseSessionService):
user_id: str,
session_id: str,
) -> ListEventsResponse:
pass
raise NotImplementedError()
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
@@ -505,7 +551,7 @@ def convert_event(event: StorageEvent) -> Event:
)
def _extract_state_delta(state: dict):
def _extract_state_delta(state: dict[str, Any]):
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
@@ -528,3 +574,10 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return content