feat: Extract content encode/decode logic to a shared util and resolve issues with JSON serialization.

feat: Update key lengths for DB tables to avoid the "key too long" error in MySQL

PiperOrigin-RevId: 753614879
Shangjie Chen authored 2025-05-01 09:12:52 -07:00 · committed by Copybara-Service
parent b691904e57
commit 14933ba470
3 changed files with 56 additions and 46 deletions

View File

@@ -0,0 +1,29 @@
"""Utility functions for session service."""

import base64
from typing import Any, Optional

from google.genai import types


def encode_content(content: types.Content):
  """Encodes a content object to a JSON dictionary."""
  encoded_content = content.model_dump(exclude_none=True)
  for p in encoded_content["parts"]:
    if "inline_data" in p:
      p["inline_data"]["data"] = base64.b64encode(
          p["inline_data"]["data"]
      ).decode("utf-8")
  return encoded_content


def decode_content(
    content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
  """Decodes a content object from a JSON dictionary."""
  if not content:
    return None
  for p in content["parts"]:
    if "inline_data" in p:
      p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
  return types.Content.model_validate(content)
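
For context (not part of the commit), a minimal round trip through the new helpers might look like the sketch below. The types.Content/Part/Blob constructors and the import path are assumptions based on the google-genai SDK and may need adjusting:

import json

from google.genai import types

from google.adk.sessions import _session_util  # assumed import path

content = types.Content(
    role="user",
    parts=[
        types.Part(text="describe this image"),
        types.Part(
            inline_data=types.Blob(mime_type="image/png", data=b"\x89PNG\r\n")
        ),
    ],
)

encoded = _session_util.encode_content(content)
json.dumps(encoded)  # raw bytes are now base64 text, so this no longer raises

decoded = _session_util.decode_content(encoded)
assert decoded.parts[1].inline_data.data == content.parts[1].inline_data.data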

View File

@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import copy import copy
from datetime import datetime from datetime import datetime
import json import json
@@ -20,13 +18,13 @@ import logging
 from typing import Any, Optional
 import uuid
 
-from google.genai import types
 from sqlalchemy import Boolean
 from sqlalchemy import delete
 from sqlalchemy import Dialect
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
 from sqlalchemy import Text
+from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine import Engine
@@ -48,6 +46,7 @@ from typing_extensions import override
 from tzlocal import get_localzone
 
 from ..events.event import Event
+from . import _session_util
 from .base_session_service import BaseSessionService
 from .base_session_service import GetSessionConfig
 from .base_session_service import ListEventsResponse
@@ -58,6 +57,7 @@ from .state import State
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_MAX_KEY_LENGTH = 128
 DEFAULT_MAX_VARCHAR_LENGTH = 256
@@ -72,14 +72,15 @@ class DynamicJSON(TypeDecorator):
   def load_dialect_impl(self, dialect: Dialect):
     if dialect.name == "postgresql":
       return dialect.type_descriptor(postgresql.JSONB)
-    else:
-      return dialect.type_descriptor(Text)  # Default to Text for other dialects
+    if dialect.name == "mysql":
+      # Use LONGTEXT for MySQL to address the data too long issue
+      return dialect.type_descriptor(mysql.LONGTEXT)
+    return dialect.type_descriptor(Text)  # Default to Text for other dialects
 
   def process_bind_param(self, value, dialect: Dialect):
     if value is not None:
       if dialect.name == "postgresql":
         return value  # JSONB handles dict directly
-      else:
-        return json.dumps(value)  # Serialize to JSON string for TEXT
+      return json.dumps(value)  # Serialize to JSON string for TEXT
     return value
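
A quick, database-free way to sanity-check this dispatch (not from the commit) is to call the methods with bare dialect instances. The stand-in class below mirrors only the two methods shown in this hunk; the name DynamicJSONDemo and the omission of process_result_value are choices made for the sketch:

import json

from sqlalchemy import Text
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.types import TypeDecorator


class DynamicJSONDemo(TypeDecorator):
  """Stand-in mirroring the dialect dispatch shown above."""

  impl = Text
  cache_ok = True

  def load_dialect_impl(self, dialect):
    if dialect.name == "postgresql":
      return dialect.type_descriptor(postgresql.JSONB)
    if dialect.name == "mysql":
      return dialect.type_descriptor(mysql.LONGTEXT)
    return dialect.type_descriptor(Text)

  def process_bind_param(self, value, dialect):
    if value is not None and dialect.name != "postgresql":
      return json.dumps(value)  # TEXT/LONGTEXT columns get a JSON string
    return value


demo = DynamicJSONDemo()
print(demo.load_dialect_impl(PGDialect()))     # JSONB
print(demo.load_dialect_impl(MySQLDialect()))  # LONGTEXT
print(demo.process_bind_param({"k": 1}, MySQLDialect()))  # '{"k": 1}'

LONGTEXT matters here because MySQL's plain TEXT tops out at 64 KB, which large base64-encoded inline_data payloads can easily exceed; that is the "data too long" issue the new comment refers to.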
@@ -104,13 +105,13 @@ class StorageSession(Base):
   __tablename__ = "sessions"
 
   app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH),
+      String(DEFAULT_MAX_KEY_LENGTH),
       primary_key=True,
       default=lambda: str(uuid.uuid4()),
   )
@@ -139,16 +140,16 @@ class StorageEvent(Base):
   __tablename__ = "events"
 
   id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   session_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
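
The 128-character key length is presumably driven by MySQL's index size limit: InnoDB caps an index key at 3072 bytes (767 on older row formats), and utf8mb4 columns are sized at 4 bytes per character. The events table's composite primary key spans four VARCHAR columns, so 4 × 256 × 4 = 4096 bytes exceeds the limit and triggers the "Specified key was too long" error, while 4 × 128 × 4 = 2048 bytes fits; non-key columns such as invocation_id stay at DEFAULT_MAX_VARCHAR_LENGTH. (This rationale is inferred from the change, not stated in the commit message.)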
@@ -209,7 +210,7 @@ class StorageAppState(Base):
   __tablename__ = "app_states"
 
   app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   state: Mapped[MutableDict[str, Any]] = mapped_column(
       MutableDict.as_mutable(DynamicJSON), default={}
@@ -224,13 +225,10 @@ class StorageUserState(Base):
   __tablename__ = "user_states"
 
   app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
-  )
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
   )
   state: Mapped[MutableDict[str, Any]] = mapped_column(
       MutableDict.as_mutable(DynamicJSON), default={}
@@ -417,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
             author=e.author,
             branch=e.branch,
             invocation_id=e.invocation_id,
-            content=_decode_content(e.content),
+            content=_session_util.decode_content(e.content),
             actions=e.actions,
             timestamp=e.timestamp.timestamp(),
             long_running_tool_ids=e.long_running_tool_ids,
@@ -540,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
           interrupted=event.interrupted,
       )
       if event.content:
-        encoded_content = event.content.model_dump(exclude_none=True)
-        # Workaround for multimodal Content throwing JSON not serializable
-        # error with SQLAlchemy.
-        for p in encoded_content["parts"]:
-          if "inline_data" in p:
-            p["inline_data"]["data"] = (
-                base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
-            )
-        storage_event.content = encoded_content
+        storage_event.content = _session_util.encode_content(event.content)
 
       sessionFactory.add(storage_event)
@@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state):
   for key in user_state.keys():
     merged_state[State.USER_PREFIX + key] = user_state[key]
   return merged_state
-
-
-def _decode_content(
-    content: Optional[dict[str, Any]],
-) -> Optional[types.Content]:
-  if not content:
-    return None
-  for p in content["parts"]:
-    if "inline_data" in p:
-      p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
-  return types.Content.model_validate(content)
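
The underlying problem both storage paths share is that the event dict is ultimately serialized with json.dumps (explicitly for TEXT/LONGTEXT, via the dialect's JSON serializer for JSONB), and json.dumps cannot handle the raw bytes that Part.inline_data.data carries for images and other blobs. The old inline workaround also left the encoded string wrapped in a one-element tuple (note the trailing comma above), which is why _decode_content had to index [0]; the shared util drops that quirk. A minimal illustration of the serialization issue, independent of the commit:

import base64
import json

part = {"inline_data": {"mime_type": "image/png", "data": b"\x89PNG\r\n"}}

try:
  json.dumps(part)
except TypeError as err:
  print(err)  # Object of type bytes is not JSON serializable

# After base64-encoding the bytes (what encode_content does), dumps succeeds.
part["inline_data"]["data"] = base64.b64encode(
    part["inline_data"]["data"]
).decode("utf-8")
print(json.dumps(part))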

View File

@@ -14,21 +14,23 @@
 
 import logging
 import re
 import time
-from typing import Any
-from typing import Optional
+from typing import Any, Optional
 
-from dateutil.parser import isoparse
+from dateutil import parser
 from google import genai
 from typing_extensions import override
 
 from ..events.event import Event
 from ..events.event_actions import EventActions
+from . import _session_util
 from .base_session_service import BaseSessionService
 from .base_session_service import GetSessionConfig
 from .base_session_service import ListEventsResponse
 from .base_session_service import ListSessionsResponse
 from .session import Session
 
+isoparse = parser.isoparse
+
 logger = logging.getLogger(__name__)
@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
   }
   event_json['actions'] = actions_json
   if event.content:
-    event_json['content'] = event.content.model_dump(exclude_none=True)
+    event_json['content'] = _session_util.encode_content(event.content)
   if event.error_code:
     event_json['error_code'] = event.error_code
   if event.error_message:
@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
       invocation_id=api_event['invocationId'],
       author=api_event['author'],
       actions=event_actions,
-      content=api_event.get('content', None),
+      content=_session_util.decode_content(api_event.get('content', None)),
       timestamp=isoparse(api_event['timestamp']).timestamp(),
       error_code=api_event.get('errorCode', None),
       error_message=api_event.get('errorMessage', None),
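
To tie the last two hunks together, here is a rough sketch (not from the commit) of how a decoded API event would flow through the shared util; the field names come from the code above, the payload values are invented, and the content dict is shown in the snake_case form produced by encode_content:

from dateutil.parser import isoparse

from google.adk.sessions import _session_util  # assumed import path

api_event = {
    'invocationId': 'inv-123',
    'author': 'user',
    'timestamp': '2025-05-01T09:12:52-07:00',
    'content': {
        'role': 'user',
        'parts': [
            {'inline_data': {'mime_type': 'image/png', 'data': 'iVBORw0KGgo='}}
        ],
    },
}

content = _session_util.decode_content(api_event.get('content', None))
assert content.parts[0].inline_data.data.startswith(b'\x89PNG')
timestamp = isoparse(api_event['timestamp']).timestamp()  # float seconds since epoch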