diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py
new file mode 100644
index 0000000..4956b34
--- /dev/null
+++ b/src/google/adk/sessions/_session_util.py
@@ -0,0 +1,56 @@
+"""Utility functions for session service."""
+
+import base64
+import copy
+from typing import Any, Optional
+
+from google.genai import types
+
+
+def encode_content(content: types.Content) -> dict[str, Any]:
+  """Encodes a content object to a JSON-serializable dictionary.
+
+  Dumps the content with `model_dump(exclude_none=True)` and replaces the raw
+  bytes of every inline_data part with base64 text (UTF-8 str).
+
+  Args:
+    content: The content object to encode.
+
+  Returns:
+    The dumped dictionary, with inline data base64-encoded.
+  """
+  encoded_content = content.model_dump(exclude_none=True)
+  # exclude_none=True drops None-valued fields, so neither "parts" nor a
+  # blob's "data" can be assumed to be present.
+  for part in encoded_content.get("parts") or []:
+    inline_data = part.get("inline_data")
+    if inline_data and inline_data.get("data") is not None:
+      inline_data["data"] = base64.b64encode(inline_data["data"]).decode(
+          "utf-8"
+      )
+  return encoded_content
+
+
+def decode_content(
+    content: Optional[dict[str, Any]],
+) -> Optional[types.Content]:
+  """Decodes a content object from a JSON dictionary.
+
+  The input dictionary is left untouched; decoding happens on a copy.
+
+  Args:
+    content: A dictionary as produced by `encode_content`, or None.
+
+  Returns:
+    The validated content object, or None if `content` is None or empty.
+  """
+  if not content:
+    return None
+  # Decoding in place would swap the caller's base64 text for raw bytes, so
+  # decoding the same dictionary a second time would no longer round-trip.
+  decoded_content = copy.deepcopy(content)
+  for part in decoded_content.get("parts") or []:
+    inline_data = part.get("inline_data")
+    if inline_data and inline_data.get("data") is not None:
+      inline_data["data"] = base64.b64decode(inline_data["data"])
+  return types.Content.model_validate(decoded_content)
diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py
index c7caa19..e58c43d 100644
--- a/src/google/adk/sessions/database_session_service.py
+++ b/src/google/adk/sessions/database_session_service.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- -import base64 import copy from datetime import datetime import json @@ -20,13 +18,13 @@ import logging from typing import Any, Optional import uuid -from google.genai import types from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func from sqlalchemy import Text +from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql from sqlalchemy.engine import create_engine from sqlalchemy.engine import Engine @@ -48,6 +46,7 @@ from typing_extensions import override from tzlocal import get_localzone from ..events.event import Event +from . import _session_util from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListEventsResponse @@ -58,6 +57,7 @@ from .state import State logger = logging.getLogger(__name__) +DEFAULT_MAX_KEY_LENGTH = 128 DEFAULT_MAX_VARCHAR_LENGTH = 256 @@ -72,15 +72,16 @@ class DynamicJSON(TypeDecorator): def load_dialect_impl(self, dialect: Dialect): if dialect.name == "postgresql": return dialect.type_descriptor(postgresql.JSONB) - else: - return dialect.type_descriptor(Text) # Default to Text for other dialects + if dialect.name == "mysql": + # Use LONGTEXT for MySQL to address the data too long issue + return dialect.type_descriptor(mysql.LONGTEXT) + return dialect.type_descriptor(Text) # Default to Text for other dialects def process_bind_param(self, value, dialect: Dialect): if value is not None: if dialect.name == "postgresql": return value # JSONB handles dict directly - else: - return json.dumps(value) # Serialize to JSON string for TEXT + return json.dumps(value) # Serialize to JSON string for TEXT return value def process_result_value(self, value, dialect: Dialect): @@ -104,13 +105,13 @@ class StorageSession(Base): __tablename__ = "sessions" app_name: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), 
primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) user_id: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) id: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True, default=lambda: str(uuid.uuid4()), ) @@ -139,16 +140,16 @@ class StorageEvent(Base): __tablename__ = "events" id: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) app_name: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) user_id: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) session_id: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) @@ -209,7 +210,7 @@ class StorageAppState(Base): __tablename__ = "app_states" app_name: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} @@ -224,13 +225,10 @@ class StorageUserState(Base): __tablename__ = "user_states" app_name: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) user_id: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(DynamicJSON), default={} + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True ) state: Mapped[MutableDict[str, Any]] = mapped_column( 
MutableDict.as_mutable(DynamicJSON), default={} @@ -417,7 +415,7 @@ class DatabaseSessionService(BaseSessionService): author=e.author, branch=e.branch, invocation_id=e.invocation_id, - content=_decode_content(e.content), + content=_session_util.decode_content(e.content), actions=e.actions, timestamp=e.timestamp.timestamp(), long_running_tool_ids=e.long_running_tool_ids, @@ -540,15 +538,7 @@ class DatabaseSessionService(BaseSessionService): interrupted=event.interrupted, ) if event.content: - encoded_content = event.content.model_dump(exclude_none=True) - # Workaround for multimodal Content throwing JSON not serializable - # error with SQLAlchemy. - for p in encoded_content["parts"]: - if "inline_data" in p: - p["inline_data"]["data"] = ( - base64.b64encode(p["inline_data"]["data"]).decode("utf-8"), - ) - storage_event.content = encoded_content + storage_event.content = _session_util.encode_content(event.content) sessionFactory.add(storage_event) @@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state): for key in user_state.keys(): merged_state[State.USER_PREFIX + key] = user_state[key] return merged_state - - -def _decode_content( - content: Optional[dict[str, Any]], -) -> Optional[types.Content]: - if not content: - return None - for p in content["parts"]: - if "inline_data" in p: - p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0]) - return types.Content.model_validate(content) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 5ec45c4..8d6fa75 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -14,21 +14,23 @@ import logging import re import time -from typing import Any -from typing import Optional +from typing import Any, Optional -from dateutil.parser import isoparse +from dateutil import parser from google import genai from typing_extensions import override from 
..events.event import Event from ..events.event_actions import EventActions +from . import _session_util from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListEventsResponse from .base_session_service import ListSessionsResponse from .session import Session + +isoparse = parser.isoparse logger = logging.getLogger(__name__) @@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event): } event_json['actions'] = actions_json if event.content: - event_json['content'] = event.content.model_dump(exclude_none=True) + event_json['content'] = _session_util.encode_content(event.content) if event.error_code: event_json['error_code'] = event.error_code if event.error_message: @@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event: invocation_id=api_event['invocationId'], author=api_event['author'], actions=event_actions, - content=api_event.get('content', None), + content=_session_util.decode_content(api_event.get('content', None)), timestamp=isoparse(api_event['timestamp']).timestamp(), error_code=api_event.get('errorCode', None), error_message=api_event.get('errorMessage', None),