fix: Make GroundingMetadata JSON serializable. Also use the same logic to simplify content serialization.

PiperOrigin-RevId: 764401248
This commit is contained in:
Shangjie Chen 2025-05-28 13:49:47 -07:00 committed by Copybara-Service
parent 7fc09b2c64
commit bf27f22a95
4 changed files with 45 additions and 37 deletions

View File

@ -11,34 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for session service."""
from __future__ import annotations
import base64
from typing import Any
from typing import Optional
from google.genai import types
def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content
def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content)
def decode_grounding_metadata(
grounding_metadata: Optional[dict[str, Any]],
) -> Optional[types.GroundingMetadata]:
"""Decodes a grounding metadata object from a JSON dictionary."""
if not grounding_metadata:
return None
return types.GroundingMetadata.model_validate(grounding_metadata)

View File

@ -21,6 +21,7 @@ from typing import Any
from typing import Optional
import uuid
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
@ -421,7 +422,9 @@ class DatabaseSessionService(BaseSessionService):
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
grounding_metadata=_session_util.decode_grounding_metadata(
e.grounding_metadata
),
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
@ -536,7 +539,6 @@ class DatabaseSessionService(BaseSessionService):
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
@ -544,7 +546,13 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted,
)
if event.content:
storage_event.content = _session_util.encode_content(event.content)
storage_event.content = event.content.model_dump(
exclude_none=True, mode="json"
)
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
exclude_none=True, mode="json"
)
session_factory.add(storage_event)

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import re
@ -18,6 +20,7 @@ from typing import Any
from typing import Optional
from dateutil import parser
from google.genai import types
from typing_extensions import override
from google import genai
@ -256,7 +259,7 @@ def _convert_event_to_json(event: Event):
}
if event.grounding_metadata:
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
exclude_none=True
exclude_none=True, mode='json'
)
event_json = {
@ -284,7 +287,9 @@ def _convert_event_to_json(event: Event):
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = _session_util.encode_content(event.content)
event_json['content'] = event.content.model_dump(
exclude_none=True, mode='json'
)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:
@ -325,8 +330,8 @@ def _from_api_event(api_event: dict) -> Event:
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
event.branch = api_event['eventMetadata'].get('branch', None)
event.grounding_metadata = api_event['eventMetadata'].get(
'groundingMetadata', None
event.grounding_metadata = _session_util.decode_grounding_metadata(
api_event['eventMetadata'].get('groundingMetadata', None)
)
event.long_running_tool_ids = (
set(long_running_tool_ids_list) if long_running_tool_ids_list else None

View File

@ -246,32 +246,33 @@ async def test_append_event_bytes(service_type):
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
test_content = types.Content(
role='user',
parts=[
types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'),
],
)
test_grounding_metadata = types.GroundingMetadata(
search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob')
)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
),
],
),
content=test_content,
grounding_metadata=test_grounding_metadata,
)
await session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
assert session.events[0].content == test_content
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
events = session.events
assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
assert events[0].content == test_content
assert events[0].grounding_metadata == test_grounding_metadata
@pytest.mark.asyncio