fix: Make GroundingMetadata JSON serializable. Also use the same logic to simplify content serialization.

PiperOrigin-RevId: 764401248
This commit is contained in:
Shangjie Chen
2025-05-28 13:49:47 -07:00
committed by Copybara-Service
parent 7fc09b2c64
commit bf27f22a95
4 changed files with 45 additions and 37 deletions

View File

@@ -11,34 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for session service."""
from __future__ import annotations
import base64
from typing import Any
from typing import Optional
from google.genai import types
def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content
def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content)
def decode_grounding_metadata(
grounding_metadata: Optional[dict[str, Any]],
) -> Optional[types.GroundingMetadata]:
"""Decodes a grounding metadata object from a JSON dictionary."""
if not grounding_metadata:
return None
return types.GroundingMetadata.model_validate(grounding_metadata)

View File

@@ -21,6 +21,7 @@ from typing import Any
from typing import Optional
import uuid
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
@@ -421,7 +422,9 @@ class DatabaseSessionService(BaseSessionService):
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
grounding_metadata=_session_util.decode_grounding_metadata(
e.grounding_metadata
),
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
@@ -536,7 +539,6 @@ class DatabaseSessionService(BaseSessionService):
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
@@ -544,7 +546,13 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted,
)
if event.content:
storage_event.content = _session_util.encode_content(event.content)
storage_event.content = event.content.model_dump(
exclude_none=True, mode="json"
)
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
exclude_none=True, mode="json"
)
session_factory.add(storage_event)

View File

@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import re
@@ -18,6 +20,7 @@ from typing import Any
from typing import Optional
from dateutil import parser
from google.genai import types
from typing_extensions import override
from google import genai
@@ -256,7 +259,7 @@ def _convert_event_to_json(event: Event):
}
if event.grounding_metadata:
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
exclude_none=True
exclude_none=True, mode='json'
)
event_json = {
@@ -284,7 +287,9 @@ def _convert_event_to_json(event: Event):
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = _session_util.encode_content(event.content)
event_json['content'] = event.content.model_dump(
exclude_none=True, mode='json'
)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:
@@ -325,8 +330,8 @@ def _from_api_event(api_event: dict) -> Event:
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
event.branch = api_event['eventMetadata'].get('branch', None)
event.grounding_metadata = api_event['eventMetadata'].get(
'groundingMetadata', None
event.grounding_metadata = _session_util.decode_grounding_metadata(
api_event['eventMetadata'].get('groundingMetadata', None)
)
event.long_running_tool_ids = (
set(long_running_tool_ids_list) if long_running_tool_ids_list else None