From 11b504c808c3db5f9c93431dcb218e2af04958ae Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Jun 2025 11:44:12 -0700 Subject: [PATCH] chore: Add functions to convert between storage and event classes PiperOrigin-RevId: 766280876 --- .../adk/sessions/database_session_service.py | 112 ++++++++---------- 1 file changed, 51 insertions(+), 61 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 4d965de..2ccd600 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -216,6 +216,55 @@ class StorageEvent(Base): else: self.long_running_tool_ids_json = json.dumps(list(value)) + @classmethod + def from_event(cls, session: Session, event: Event) -> StorageEvent: + storage_event = StorageEvent( + id=event.id, + invocation_id=event.invocation_id, + author=event.author, + branch=event.branch, + actions=event.actions, + session_id=session.id, + app_name=session.app_name, + user_id=session.user_id, + timestamp=datetime.fromtimestamp(event.timestamp), + long_running_tool_ids=event.long_running_tool_ids, + partial=event.partial, + turn_complete=event.turn_complete, + error_code=event.error_code, + error_message=event.error_message, + interrupted=event.interrupted, + ) + if event.content: + storage_event.content = event.content.model_dump( + exclude_none=True, mode="json" + ) + if event.grounding_metadata: + storage_event.grounding_metadata = event.grounding_metadata.model_dump( + exclude_none=True, mode="json" + ) + return storage_event + + def to_event(self) -> Event: + return Event( + id=self.id, + invocation_id=self.invocation_id, + author=self.author, + branch=self.branch, + actions=self.actions, + timestamp=self.timestamp.timestamp(), + content=_session_util.decode_content(self.content), + long_running_tool_ids=self.long_running_tool_ids, + partial=self.partial, + turn_complete=self.turn_complete, + error_code=self.error_code, + error_message=self.error_message, + interrupted=self.interrupted, + grounding_metadata=_session_util.decode_grounding_metadata( + self.grounding_metadata + ), + ) + class StorageAppState(Base): """Represents an app state stored in the database.""" @@ -426,27 +475,7 @@ class DatabaseSessionService(BaseSessionService): state=merged_state, last_update_time=storage_session.update_time.timestamp(), ) - session.events = [ - Event( - id=e.id, - author=e.author, - branch=e.branch, - invocation_id=e.invocation_id, - content=_session_util.decode_content(e.content), - actions=e.actions, - timestamp=e.timestamp.timestamp(), - long_running_tool_ids=e.long_running_tool_ids, - grounding_metadata=_session_util.decode_grounding_metadata( - e.grounding_metadata - ), - partial=e.partial, - turn_complete=e.turn_complete, - error_code=e.error_code, - error_message=e.error_message, - interrupted=e.interrupted, - ) - for e in reversed(storage_events) - ] + session.events = [e.to_event() for e in reversed(storage_events)] return session @override @@ -542,33 +571,7 @@ class DatabaseSessionService(BaseSessionService): session_state.update(session_state_delta) storage_session.state = session_state - storage_event = StorageEvent( - id=event.id, - invocation_id=event.invocation_id, - author=event.author, - branch=event.branch, - actions=event.actions, - session_id=session.id, - app_name=session.app_name, - user_id=session.user_id, - timestamp=datetime.fromtimestamp(event.timestamp), - long_running_tool_ids=event.long_running_tool_ids, - partial=event.partial, - turn_complete=event.turn_complete, - error_code=event.error_code, - error_message=event.error_message, - interrupted=event.interrupted, - ) - if event.content: - storage_event.content = event.content.model_dump( - exclude_none=True, mode="json" - ) - if event.grounding_metadata: - storage_event.grounding_metadata = event.grounding_metadata.model_dump( - exclude_none=True, mode="json" - ) - - session_factory.add(storage_event) + session_factory.add(StorageEvent.from_event(session, event)) session_factory.commit() session_factory.refresh(storage_session) @@ -581,19 +584,6 @@ class DatabaseSessionService(BaseSessionService): return event -def convert_event(event: StorageEvent) -> Event: - """Converts a storage event to an event.""" - return Event( - id=event.id, - author=event.author, - branch=event.branch, - invocation_id=event.invocation_id, - content=event.content, - actions=event.actions, - timestamp=event.timestamp.timestamp(), - ) - - def _extract_state_delta(state: dict[str, Any]): app_state_delta = {} user_state_delta = {}