fix: Make GroundingMetadata JSON serializable. Also use the same logic to simplify content serialization.

PiperOrigin-RevId: 764401248
This commit is contained in:
Shangjie Chen 2025-05-28 13:49:47 -07:00 committed by Copybara-Service
parent 7fc09b2c64
commit bf27f22a95
4 changed files with 45 additions and 37 deletions

View File

@ -11,34 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Utility functions for session service.""" """Utility functions for session service."""
from __future__ import annotations
import base64
from typing import Any from typing import Any
from typing import Optional from typing import Optional
from google.genai import types from google.genai import types
def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content
def decode_content( def decode_content(
content: Optional[dict[str, Any]], content: Optional[dict[str, Any]],
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary.""" """Decodes a content object from a JSON dictionary."""
if not content: if not content:
return None return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content) return types.Content.model_validate(content)
def decode_grounding_metadata(
grounding_metadata: Optional[dict[str, Any]],
) -> Optional[types.GroundingMetadata]:
"""Decodes a grounding metadata object from a JSON dictionary."""
if not grounding_metadata:
return None
return types.GroundingMetadata.model_validate(grounding_metadata)

View File

@ -21,6 +21,7 @@ from typing import Any
from typing import Optional from typing import Optional
import uuid import uuid
from google.genai import types
from sqlalchemy import Boolean from sqlalchemy import Boolean
from sqlalchemy import delete from sqlalchemy import delete
from sqlalchemy import Dialect from sqlalchemy import Dialect
@ -421,7 +422,9 @@ class DatabaseSessionService(BaseSessionService):
actions=e.actions, actions=e.actions,
timestamp=e.timestamp.timestamp(), timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids, long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata, grounding_metadata=_session_util.decode_grounding_metadata(
e.grounding_metadata
),
partial=e.partial, partial=e.partial,
turn_complete=e.turn_complete, turn_complete=e.turn_complete,
error_code=e.error_code, error_code=e.error_code,
@ -536,7 +539,6 @@ class DatabaseSessionService(BaseSessionService):
user_id=session.user_id, user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp), timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids, long_running_tool_ids=event.long_running_tool_ids,
grounding_metadata=event.grounding_metadata,
partial=event.partial, partial=event.partial,
turn_complete=event.turn_complete, turn_complete=event.turn_complete,
error_code=event.error_code, error_code=event.error_code,
@ -544,7 +546,13 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted, interrupted=event.interrupted,
) )
if event.content: if event.content:
storage_event.content = _session_util.encode_content(event.content) storage_event.content = event.content.model_dump(
exclude_none=True, mode="json"
)
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
exclude_none=True, mode="json"
)
session_factory.add(storage_event) session_factory.add(storage_event)

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import re import re
@ -18,6 +20,7 @@ from typing import Any
from typing import Optional from typing import Optional
from dateutil import parser from dateutil import parser
from google.genai import types
from typing_extensions import override from typing_extensions import override
from google import genai from google import genai
@ -256,7 +259,7 @@ def _convert_event_to_json(event: Event):
} }
if event.grounding_metadata: if event.grounding_metadata:
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump( metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
exclude_none=True exclude_none=True, mode='json'
) )
event_json = { event_json = {
@ -284,7 +287,9 @@ def _convert_event_to_json(event: Event):
} }
event_json['actions'] = actions_json event_json['actions'] = actions_json
if event.content: if event.content:
event_json['content'] = _session_util.encode_content(event.content) event_json['content'] = event.content.model_dump(
exclude_none=True, mode='json'
)
if event.error_code: if event.error_code:
event_json['error_code'] = event.error_code event_json['error_code'] = event.error_code
if event.error_message: if event.error_message:
@ -325,8 +330,8 @@ def _from_api_event(api_event: dict) -> Event:
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None) event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
event.interrupted = api_event['eventMetadata'].get('interrupted', None) event.interrupted = api_event['eventMetadata'].get('interrupted', None)
event.branch = api_event['eventMetadata'].get('branch', None) event.branch = api_event['eventMetadata'].get('branch', None)
event.grounding_metadata = api_event['eventMetadata'].get( event.grounding_metadata = _session_util.decode_grounding_metadata(
'groundingMetadata', None api_event['eventMetadata'].get('groundingMetadata', None)
) )
event.long_running_tool_ids = ( event.long_running_tool_ids = (
set(long_running_tool_ids_list) if long_running_tool_ids_list else None set(long_running_tool_ids_list) if long_running_tool_ids_list else None

View File

@ -246,32 +246,33 @@ async def test_append_event_bytes(service_type):
session = await session_service.create_session( session = await session_service.create_session(
app_name=app_name, user_id=user_id app_name=app_name, user_id=user_id
) )
test_content = types.Content(
role='user',
parts=[
types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'),
],
)
test_grounding_metadata = types.GroundingMetadata(
search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob')
)
event = Event( event = Event(
invocation_id='invocation', invocation_id='invocation',
author='user', author='user',
content=types.Content( content=test_content,
role='user', grounding_metadata=test_grounding_metadata,
parts=[
types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
),
],
),
) )
await session_service.append_event(session=session, event=event) await session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes( assert session.events[0].content == test_content
data=b'test_image_data', mime_type='image/png'
)
session = await session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
events = session.events events = session.events
assert len(events) == 1 assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes( assert events[0].content == test_content
data=b'test_image_data', mime_type='image/png' assert events[0].grounding_metadata == test_grounding_metadata
)
@pytest.mark.asyncio @pytest.mark.asyncio