chore: Add functions to convert between storage and event classes

PiperOrigin-RevId: 766280876
This commit is contained in:
Google Team Member 2025-06-02 11:44:12 -07:00 committed by Copybara-Service
parent 15a45a68fd
commit 11b504c808

View File

@ -216,6 +216,55 @@ class StorageEvent(Base):
else:
self.long_running_tool_ids_json = json.dumps(list(value))
@classmethod
def from_event(cls, session: Session, event: Event) -> StorageEvent:
storage_event = StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
if event.content:
storage_event.content = event.content.model_dump(
exclude_none=True, mode="json"
)
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
exclude_none=True, mode="json"
)
return storage_event
def to_event(self) -> Event:
return Event(
id=self.id,
invocation_id=self.invocation_id,
author=self.author,
branch=self.branch,
actions=self.actions,
timestamp=self.timestamp.timestamp(),
content=_session_util.decode_content(self.content),
long_running_tool_ids=self.long_running_tool_ids,
partial=self.partial,
turn_complete=self.turn_complete,
error_code=self.error_code,
error_message=self.error_message,
interrupted=self.interrupted,
grounding_metadata=_session_util.decode_grounding_metadata(
self.grounding_metadata
),
)
class StorageAppState(Base):
"""Represents an app state stored in the database."""
@ -426,27 +475,7 @@ class DatabaseSessionService(BaseSessionService):
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
session.events = [
Event(
id=e.id,
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_session_util.decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=_session_util.decode_grounding_metadata(
e.grounding_metadata
),
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
error_message=e.error_message,
interrupted=e.interrupted,
)
for e in reversed(storage_events)
]
session.events = [e.to_event() for e in reversed(storage_events)]
return session
@override
@ -542,33 +571,7 @@ class DatabaseSessionService(BaseSessionService):
session_state.update(session_state_delta)
storage_session.state = session_state
storage_event = StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
if event.content:
storage_event.content = event.content.model_dump(
exclude_none=True, mode="json"
)
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
exclude_none=True, mode="json"
)
session_factory.add(storage_event)
session_factory.add(StorageEvent.from_event(session, event))
session_factory.commit()
session_factory.refresh(storage_session)
@ -581,19 +584,6 @@ class DatabaseSessionService(BaseSessionService):
return event
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
return Event(
id=event.id,
author=event.author,
branch=event.branch,
invocation_id=event.invocation_id,
content=event.content,
actions=event.actions,
timestamp=event.timestamp.timestamp(),
)
def _extract_state_delta(state: dict[str, Any]):
app_state_delta = {}
user_state_delta = {}