From bf27f22a9534279b942bb8047d747effc9e7dd7a Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 28 May 2025 13:49:47 -0700 Subject: [PATCH] fix: Make GroundingMetadata JSON serializable. Also use the same logic to simplify content serialization. PiperOrigin-RevId: 764401248 --- src/google/adk/sessions/_session_util.py | 26 +++++++---------- .../adk/sessions/database_session_service.py | 14 +++++++-- .../adk/sessions/vertex_ai_session_service.py | 13 ++++++--- .../sessions/test_session_service.py | 29 ++++++++++--------- 4 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py index a55df7d..2cc6594 100644 --- a/src/google/adk/sessions/_session_util.py +++ b/src/google/adk/sessions/_session_util.py @@ -11,34 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utility functions for session service.""" +from __future__ import annotations -import base64 from typing import Any from typing import Optional from google.genai import types -def encode_content(content: types.Content): - """Encodes a content object to a JSON dictionary.""" - encoded_content = content.model_dump(exclude_none=True) - for p in encoded_content["parts"]: - if "inline_data" in p: - p["inline_data"]["data"] = base64.b64encode( - p["inline_data"]["data"] - ).decode("utf-8") - return encoded_content - - def decode_content( content: Optional[dict[str, Any]], ) -> Optional[types.Content]: """Decodes a content object from a JSON dictionary.""" if not content: return None - for p in content["parts"]: - if "inline_data" in p: - p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"]) return types.Content.model_validate(content) + + +def decode_grounding_metadata( + grounding_metadata: Optional[dict[str, Any]], +) -> Optional[types.GroundingMetadata]: + """Decodes a grounding metadata object from a JSON dictionary.""" + if not grounding_metadata: + return None + return types.GroundingMetadata.model_validate(grounding_metadata) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 8770d93..7c64548 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -21,6 +21,7 @@ from typing import Any from typing import Optional import uuid +from google.genai import types from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect @@ -421,7 +422,9 @@ class DatabaseSessionService(BaseSessionService): actions=e.actions, timestamp=e.timestamp.timestamp(), long_running_tool_ids=e.long_running_tool_ids, - grounding_metadata=e.grounding_metadata, + grounding_metadata=_session_util.decode_grounding_metadata( + e.grounding_metadata + ), partial=e.partial, turn_complete=e.turn_complete, error_code=e.error_code, @@ -536,7 +539,6 @@ class DatabaseSessionService(BaseSessionService): user_id=session.user_id, timestamp=datetime.fromtimestamp(event.timestamp), long_running_tool_ids=event.long_running_tool_ids, - grounding_metadata=event.grounding_metadata, partial=event.partial, turn_complete=event.turn_complete, error_code=event.error_code, @@ -544,7 +546,13 @@ class DatabaseSessionService(BaseSessionService): interrupted=event.interrupted, ) if event.content: - storage_event.content = _session_util.encode_content(event.content) + storage_event.content = event.content.model_dump( + exclude_none=True, mode="json" + ) + if event.grounding_metadata: + storage_event.grounding_metadata = event.grounding_metadata.model_dump( + exclude_none=True, mode="json" + ) session_factory.add(storage_event) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 1352728..a147bbe 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import asyncio import logging import re @@ -18,6 +20,7 @@ from typing import Any from typing import Optional from dateutil import parser +from google.genai import types from typing_extensions import override from google import genai @@ -256,7 +259,7 @@ def _convert_event_to_json(event: Event): } if event.grounding_metadata: metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump( - exclude_none=True + exclude_none=True, mode='json' ) event_json = { @@ -284,7 +287,9 @@ def _convert_event_to_json(event: Event): } event_json['actions'] = actions_json if event.content: - event_json['content'] = _session_util.encode_content(event.content) + event_json['content'] = event.content.model_dump( + exclude_none=True, mode='json' + ) if event.error_code: event_json['error_code'] = event.error_code if event.error_message: @@ -325,8 +330,8 @@ def _from_api_event(api_event: dict) -> Event: event.turn_complete = api_event['eventMetadata'].get('turnComplete', None) event.interrupted = api_event['eventMetadata'].get('interrupted', None) event.branch = api_event['eventMetadata'].get('branch', None) - event.grounding_metadata = api_event['eventMetadata'].get( - 'groundingMetadata', None + event.grounding_metadata = _session_util.decode_grounding_metadata( + api_event['eventMetadata'].get('groundingMetadata', None) ) event.long_running_tool_ids = ( set(long_running_tool_ids_list) if long_running_tool_ids_list else None diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 676fb7d..ec93caa 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -246,32 +246,33 @@ async def test_append_event_bytes(service_type): session = await session_service.create_session( app_name=app_name, user_id=user_id ) + + test_content = types.Content( + role='user', + parts=[ + types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'), + ], + ) + test_grounding_metadata = types.GroundingMetadata( + search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob') + ) event = Event( invocation_id='invocation', author='user', - content=types.Content( - role='user', - parts=[ - types.Part.from_bytes( - data=b'test_image_data', mime_type='image/png' - ), - ], - ), + content=test_content, + grounding_metadata=test_grounding_metadata, ) await session_service.append_event(session=session, event=event) - assert session.events[0].content.parts[0] == types.Part.from_bytes( - data=b'test_image_data', mime_type='image/png' - ) + assert session.events[0].content == test_content session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) events = session.events assert len(events) == 1 - assert events[0].content.parts[0] == types.Part.from_bytes( - data=b'test_image_data', mime_type='image/png' - ) + assert events[0].content == test_content + assert events[0].grounding_metadata == test_grounding_metadata @pytest.mark.asyncio