mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
fix: Make GroundingMetadata JSON serializable. Also use the same logic to simplify content serialization.
PiperOrigin-RevId: 764401248
This commit is contained in:
parent
7fc09b2c64
commit
bf27f22a95
@ -11,34 +11,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions for session service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def encode_content(content: types.Content):
|
||||
"""Encodes a content object to a JSON dictionary."""
|
||||
encoded_content = content.model_dump(exclude_none=True)
|
||||
for p in encoded_content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = base64.b64encode(
|
||||
p["inline_data"]["data"]
|
||||
).decode("utf-8")
|
||||
return encoded_content
|
||||
|
||||
|
||||
def decode_content(
|
||||
content: Optional[dict[str, Any]],
|
||||
) -> Optional[types.Content]:
|
||||
"""Decodes a content object from a JSON dictionary."""
|
||||
if not content:
|
||||
return None
|
||||
for p in content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
|
||||
return types.Content.model_validate(content)
|
||||
|
||||
|
||||
def decode_grounding_metadata(
|
||||
grounding_metadata: Optional[dict[str, Any]],
|
||||
) -> Optional[types.GroundingMetadata]:
|
||||
"""Decodes a grounding metadata object from a JSON dictionary."""
|
||||
if not grounding_metadata:
|
||||
return None
|
||||
return types.GroundingMetadata.model_validate(grounding_metadata)
|
||||
|
@ -21,6 +21,7 @@ from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import Dialect
|
||||
@ -421,7 +422,9 @@ class DatabaseSessionService(BaseSessionService):
|
||||
actions=e.actions,
|
||||
timestamp=e.timestamp.timestamp(),
|
||||
long_running_tool_ids=e.long_running_tool_ids,
|
||||
grounding_metadata=e.grounding_metadata,
|
||||
grounding_metadata=_session_util.decode_grounding_metadata(
|
||||
e.grounding_metadata
|
||||
),
|
||||
partial=e.partial,
|
||||
turn_complete=e.turn_complete,
|
||||
error_code=e.error_code,
|
||||
@ -536,7 +539,6 @@ class DatabaseSessionService(BaseSessionService):
|
||||
user_id=session.user_id,
|
||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||
long_running_tool_ids=event.long_running_tool_ids,
|
||||
grounding_metadata=event.grounding_metadata,
|
||||
partial=event.partial,
|
||||
turn_complete=event.turn_complete,
|
||||
error_code=event.error_code,
|
||||
@ -544,7 +546,13 @@ class DatabaseSessionService(BaseSessionService):
|
||||
interrupted=event.interrupted,
|
||||
)
|
||||
if event.content:
|
||||
storage_event.content = _session_util.encode_content(event.content)
|
||||
storage_event.content = event.content.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.grounding_metadata:
|
||||
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
|
||||
session_factory.add(storage_event)
|
||||
|
||||
|
@ -11,6 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
@ -18,6 +20,7 @@ from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from dateutil import parser
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from google import genai
|
||||
@ -256,7 +259,7 @@ def _convert_event_to_json(event: Event):
|
||||
}
|
||||
if event.grounding_metadata:
|
||||
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
|
||||
exclude_none=True
|
||||
exclude_none=True, mode='json'
|
||||
)
|
||||
|
||||
event_json = {
|
||||
@ -284,7 +287,9 @@ def _convert_event_to_json(event: Event):
|
||||
}
|
||||
event_json['actions'] = actions_json
|
||||
if event.content:
|
||||
event_json['content'] = _session_util.encode_content(event.content)
|
||||
event_json['content'] = event.content.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)
|
||||
if event.error_code:
|
||||
event_json['error_code'] = event.error_code
|
||||
if event.error_message:
|
||||
@ -325,8 +330,8 @@ def _from_api_event(api_event: dict) -> Event:
|
||||
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
|
||||
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
|
||||
event.branch = api_event['eventMetadata'].get('branch', None)
|
||||
event.grounding_metadata = api_event['eventMetadata'].get(
|
||||
'groundingMetadata', None
|
||||
event.grounding_metadata = _session_util.decode_grounding_metadata(
|
||||
api_event['eventMetadata'].get('groundingMetadata', None)
|
||||
)
|
||||
event.long_running_tool_ids = (
|
||||
set(long_running_tool_ids_list) if long_running_tool_ids_list else None
|
||||
|
@ -246,32 +246,33 @@ async def test_append_event_bytes(service_type):
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
|
||||
test_content = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'),
|
||||
],
|
||||
)
|
||||
test_grounding_metadata = types.GroundingMetadata(
|
||||
search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob')
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
),
|
||||
],
|
||||
),
|
||||
content=test_content,
|
||||
grounding_metadata=test_grounding_metadata,
|
||||
)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
assert session.events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
assert session.events[0].content == test_content
|
||||
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == 1
|
||||
assert events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
assert events[0].content == test_content
|
||||
assert events[0].grounding_metadata == test_grounding_metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user