mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 04:02:55 -06:00
fix: Make GroundingMetadata JSON serializable. Also use the same logic to simplify content serialization.
PiperOrigin-RevId: 764401248
This commit is contained in:
parent
7fc09b2c64
commit
bf27f22a95
@ -11,34 +11,28 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Utility functions for session service."""
|
"""Utility functions for session service."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
|
|
||||||
def encode_content(content: types.Content):
|
|
||||||
"""Encodes a content object to a JSON dictionary."""
|
|
||||||
encoded_content = content.model_dump(exclude_none=True)
|
|
||||||
for p in encoded_content["parts"]:
|
|
||||||
if "inline_data" in p:
|
|
||||||
p["inline_data"]["data"] = base64.b64encode(
|
|
||||||
p["inline_data"]["data"]
|
|
||||||
).decode("utf-8")
|
|
||||||
return encoded_content
|
|
||||||
|
|
||||||
|
|
||||||
def decode_content(
|
def decode_content(
|
||||||
content: Optional[dict[str, Any]],
|
content: Optional[dict[str, Any]],
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Decodes a content object from a JSON dictionary."""
|
"""Decodes a content object from a JSON dictionary."""
|
||||||
if not content:
|
if not content:
|
||||||
return None
|
return None
|
||||||
for p in content["parts"]:
|
|
||||||
if "inline_data" in p:
|
|
||||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
|
|
||||||
return types.Content.model_validate(content)
|
return types.Content.model_validate(content)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_grounding_metadata(
|
||||||
|
grounding_metadata: Optional[dict[str, Any]],
|
||||||
|
) -> Optional[types.GroundingMetadata]:
|
||||||
|
"""Decodes a grounding metadata object from a JSON dictionary."""
|
||||||
|
if not grounding_metadata:
|
||||||
|
return None
|
||||||
|
return types.GroundingMetadata.model_validate(grounding_metadata)
|
||||||
|
@ -21,6 +21,7 @@ from typing import Any
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from google.genai import types
|
||||||
from sqlalchemy import Boolean
|
from sqlalchemy import Boolean
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
from sqlalchemy import Dialect
|
from sqlalchemy import Dialect
|
||||||
@ -421,7 +422,9 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
actions=e.actions,
|
actions=e.actions,
|
||||||
timestamp=e.timestamp.timestamp(),
|
timestamp=e.timestamp.timestamp(),
|
||||||
long_running_tool_ids=e.long_running_tool_ids,
|
long_running_tool_ids=e.long_running_tool_ids,
|
||||||
grounding_metadata=e.grounding_metadata,
|
grounding_metadata=_session_util.decode_grounding_metadata(
|
||||||
|
e.grounding_metadata
|
||||||
|
),
|
||||||
partial=e.partial,
|
partial=e.partial,
|
||||||
turn_complete=e.turn_complete,
|
turn_complete=e.turn_complete,
|
||||||
error_code=e.error_code,
|
error_code=e.error_code,
|
||||||
@ -536,7 +539,6 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||||
long_running_tool_ids=event.long_running_tool_ids,
|
long_running_tool_ids=event.long_running_tool_ids,
|
||||||
grounding_metadata=event.grounding_metadata,
|
|
||||||
partial=event.partial,
|
partial=event.partial,
|
||||||
turn_complete=event.turn_complete,
|
turn_complete=event.turn_complete,
|
||||||
error_code=event.error_code,
|
error_code=event.error_code,
|
||||||
@ -544,7 +546,13 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
interrupted=event.interrupted,
|
interrupted=event.interrupted,
|
||||||
)
|
)
|
||||||
if event.content:
|
if event.content:
|
||||||
storage_event.content = _session_util.encode_content(event.content)
|
storage_event.content = event.content.model_dump(
|
||||||
|
exclude_none=True, mode="json"
|
||||||
|
)
|
||||||
|
if event.grounding_metadata:
|
||||||
|
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
|
||||||
|
exclude_none=True, mode="json"
|
||||||
|
)
|
||||||
|
|
||||||
session_factory.add(storage_event)
|
session_factory.add(storage_event)
|
||||||
|
|
||||||
|
@ -11,6 +11,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -18,6 +20,7 @@ from typing import Any
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from dateutil import parser
|
from dateutil import parser
|
||||||
|
from google.genai import types
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
@ -256,7 +259,7 @@ def _convert_event_to_json(event: Event):
|
|||||||
}
|
}
|
||||||
if event.grounding_metadata:
|
if event.grounding_metadata:
|
||||||
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
|
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
|
||||||
exclude_none=True
|
exclude_none=True, mode='json'
|
||||||
)
|
)
|
||||||
|
|
||||||
event_json = {
|
event_json = {
|
||||||
@ -284,7 +287,9 @@ def _convert_event_to_json(event: Event):
|
|||||||
}
|
}
|
||||||
event_json['actions'] = actions_json
|
event_json['actions'] = actions_json
|
||||||
if event.content:
|
if event.content:
|
||||||
event_json['content'] = _session_util.encode_content(event.content)
|
event_json['content'] = event.content.model_dump(
|
||||||
|
exclude_none=True, mode='json'
|
||||||
|
)
|
||||||
if event.error_code:
|
if event.error_code:
|
||||||
event_json['error_code'] = event.error_code
|
event_json['error_code'] = event.error_code
|
||||||
if event.error_message:
|
if event.error_message:
|
||||||
@ -325,8 +330,8 @@ def _from_api_event(api_event: dict) -> Event:
|
|||||||
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
|
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
|
||||||
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
|
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
|
||||||
event.branch = api_event['eventMetadata'].get('branch', None)
|
event.branch = api_event['eventMetadata'].get('branch', None)
|
||||||
event.grounding_metadata = api_event['eventMetadata'].get(
|
event.grounding_metadata = _session_util.decode_grounding_metadata(
|
||||||
'groundingMetadata', None
|
api_event['eventMetadata'].get('groundingMetadata', None)
|
||||||
)
|
)
|
||||||
event.long_running_tool_ids = (
|
event.long_running_tool_ids = (
|
||||||
set(long_running_tool_ids_list) if long_running_tool_ids_list else None
|
set(long_running_tool_ids_list) if long_running_tool_ids_list else None
|
||||||
|
@ -246,32 +246,33 @@ async def test_append_event_bytes(service_type):
|
|||||||
session = await session_service.create_session(
|
session = await session_service.create_session(
|
||||||
app_name=app_name, user_id=user_id
|
app_name=app_name, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
test_content = types.Content(
|
||||||
|
role='user',
|
||||||
|
parts=[
|
||||||
|
types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
test_grounding_metadata = types.GroundingMetadata(
|
||||||
|
search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob')
|
||||||
|
)
|
||||||
event = Event(
|
event = Event(
|
||||||
invocation_id='invocation',
|
invocation_id='invocation',
|
||||||
author='user',
|
author='user',
|
||||||
content=types.Content(
|
content=test_content,
|
||||||
role='user',
|
grounding_metadata=test_grounding_metadata,
|
||||||
parts=[
|
|
||||||
types.Part.from_bytes(
|
|
||||||
data=b'test_image_data', mime_type='image/png'
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
await session_service.append_event(session=session, event=event)
|
await session_service.append_event(session=session, event=event)
|
||||||
|
|
||||||
assert session.events[0].content.parts[0] == types.Part.from_bytes(
|
assert session.events[0].content == test_content
|
||||||
data=b'test_image_data', mime_type='image/png'
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await session_service.get_session(
|
session = await session_service.get_session(
|
||||||
app_name=app_name, user_id=user_id, session_id=session.id
|
app_name=app_name, user_id=user_id, session_id=session.id
|
||||||
)
|
)
|
||||||
events = session.events
|
events = session.events
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].content.parts[0] == types.Part.from_bytes(
|
assert events[0].content == test_content
|
||||||
data=b'test_image_data', mime_type='image/png'
|
assert events[0].grounding_metadata == test_grounding_metadata
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user