feat: Extract content encode/decode logic to a shared util and resolve issues with JSON serialization.

feat: Update key lengths for DB tables to avoid the 'key too long' error in MySQL

PiperOrigin-RevId: 753614879
Shangjie Chen 2025-05-01 09:12:52 -07:00 committed by Copybara-Service
parent b691904e57
commit 14933ba470
3 changed files with 56 additions and 46 deletions
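Background on the serialization fix: Part inline data is raw bytes, and json.dumps rejects bytes, which is what SQLAlchemy runs into when binding the content dict to a TEXT column. A standalone, stdlib-only sketch of the failure and the base64 workaround the new helper applies:

import base64
import json

part = {"inline_data": {"mime_type": "image/png", "data": b"\x89PNG\r\n"}}

try:
  json.dumps(part)  # TypeError: Object of type bytes is not JSON serializable
except TypeError as err:
  print(err)

# Base64-encode the payload first (what encode_content does) and it serializes cleanly.
part["inline_data"]["data"] = base64.b64encode(part["inline_data"]["data"]).decode("utf-8")
print(json.dumps(part))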

View File

@@ -0,0 +1,29 @@
"""Utility functions for session service."""
import base64
from typing import Any, Optional
from google.genai import types
def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content
def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content)
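A round-trip sketch for the new helpers; the import path is assumed here since the file name is not shown in this view:

from google.adk.sessions import _session_util  # assumed location of the new module
from google.genai import types

content = types.Content(
    role="user",
    parts=[
        types.Part(text="describe this image"),
        types.Part(inline_data=types.Blob(mime_type="image/png", data=b"\x89PNG\r\n")),
    ],
)

encoded = _session_util.encode_content(content)  # JSON-safe dict, bytes -> base64 str
decoded = _session_util.decode_content(encoded)  # back to types.Content with raw bytes
assert decoded.parts[1].inline_data.data == b"\x89PNG\r\n"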

View File

@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from datetime import datetime
import json
@@ -20,13 +18,13 @@ import logging
from typing import Any, Optional
import uuid
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -48,6 +46,7 @@ from typing_extensions import override
from tzlocal import get_localzone
from ..events.event import Event
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
@@ -58,6 +57,7 @@ from .state import State
logger = logging.getLogger(__name__)
DEFAULT_MAX_KEY_LENGTH = 128
DEFAULT_MAX_VARCHAR_LENGTH = 256
@@ -72,15 +72,16 @@ class DynamicJSON(TypeDecorator):
def load_dialect_impl(self, dialect: Dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
return dialect.type_descriptor(Text) # Default to Text for other dialects
if dialect.name == "mysql":
# Use LONGTEXT for MySQL to address the data too long issue
return dialect.type_descriptor(mysql.LONGTEXT)
return dialect.type_descriptor(Text) # Default to Text for other dialects
def process_bind_param(self, value, dialect: Dialect):
if value is not None:
if dialect.name == "postgresql":
return value # JSONB handles dict directly
else:
return json.dumps(value) # Serialize to JSON string for TEXT
return json.dumps(value) # Serialize to JSON string for TEXT
return value
def process_result_value(self, value, dialect: Dialect):
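Rough numbers behind the LONGTEXT switch: MySQL's TEXT type caps out at 65,535 bytes, base64 inflates binary payloads by roughly 4/3, and that string is then embedded in the JSON-encoded content, so even a modest image overflows plain TEXT and triggers the 'Data too long' error. LONGTEXT (up to 4 GiB) avoids that; PostgreSQL keeps using JSONB and other dialects keep the TEXT fallback. A quick illustration:

import base64

image = b"\x00" * 100_000                   # ~100 KB of binary inline data
encoded_len = len(base64.b64encode(image))  # 4 * ceil(100_000 / 3) = 133_336
print(encoded_len, encoded_len > 65_535)    # 133336 True -> too long for MySQL TEXT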
@@ -104,13 +105,13 @@ class StorageSession(Base):
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH),
String(DEFAULT_MAX_KEY_LENGTH),
primary_key=True,
default=lambda: str(uuid.uuid4()),
)
@@ -139,16 +140,16 @@ class StorageEvent(Base):
__tablename__ = "events"
id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
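The switch from DEFAULT_MAX_VARCHAR_LENGTH to a 128-character DEFAULT_MAX_KEY_LENGTH is driven by MySQL's index size limit rather than by the data: with utf8mb4 each character counts 4 bytes toward an index key, and InnoDB caps a key at 3072 bytes (767 on older row formats), so this table's four-column composite primary key needs 4 × 256 × 4 = 4096 bytes at the old length but only 2048 bytes at 128. A back-of-the-envelope check, assuming utf8mb4 and the DYNAMIC row format:

BYTES_PER_CHAR = 4        # utf8mb4
INNODB_KEY_LIMIT = 3072   # bytes, DYNAMIC/COMPRESSED row formats

old_key = 4 * 256 * BYTES_PER_CHAR  # id + app_name + user_id + session_id at VARCHAR(256)
new_key = 4 * 128 * BYTES_PER_CHAR  # same columns at VARCHAR(128)
print(old_key, old_key <= INNODB_KEY_LIMIT)  # 4096 False -> "Specified key was too long"
print(new_key, new_key <= INNODB_KEY_LIMIT)  # 2048 True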
@@ -209,7 +210,7 @@ class StorageAppState(Base):
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
@@ -224,13 +225,10 @@ class StorageUserState(Base):
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
@@ -417,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_decode_content(e.content),
content=_session_util.decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
@@ -540,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
storage_event.content = _session_util.encode_content(event.content)
sessionFactory.add(storage_event)
@@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return types.Content.model_validate(content)
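One detail worth noting in the removed code above: the old inline encoder had a stray trailing comma, so it stored the base64 string wrapped in a one-element tuple (a JSON array once serialized), and the old _decode_content compensated by indexing [0]. The shared helpers store a plain string and decode it directly. A tiny plain-Python illustration of that quirk:

import base64

data = b"\x89PNG"
old_style = (base64.b64encode(data).decode("utf-8"),)  # trailing comma -> 1-tuple
new_style = base64.b64encode(data).decode("utf-8")     # plain string

assert base64.b64decode(old_style[0]) == base64.b64decode(new_style) == data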

View File

@@ -14,21 +14,23 @@
import logging
import re
import time
from typing import Any
from typing import Optional
from typing import Any, Optional
from dateutil.parser import isoparse
from dateutil import parser
from google import genai
from typing_extensions import override
from ..events.event import Event
from ..events.event_actions import EventActions
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session
isoparse = parser.isoparse
logger = logging.getLogger(__name__)
@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = event.content.model_dump(exclude_none=True)
event_json['content'] = _session_util.encode_content(event.content)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:
@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
invocation_id=api_event['invocationId'],
author=api_event['author'],
actions=event_actions,
content=api_event.get('content', None),
content=_session_util.decode_content(api_event.get('content', None)),
timestamp=isoparse(api_event['timestamp']).timestamp(),
error_code=api_event.get('errorCode', None),
error_message=api_event.get('errorMessage', None),
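For reference, a sketch of what this conversion path now does with a raw API event; the payload shape below is illustrative, not taken from the API reference, and the import path is assumed:

from dateutil.parser import isoparse
from google.adk.sessions import _session_util  # assumed location of the shared module

api_event = {
    "invocationId": "inv-123",
    "author": "model",
    "timestamp": "2025-05-01T09:12:52-07:00",
    "content": {"role": "model", "parts": [{"text": "hello"}]},
}

content = _session_util.decode_content(api_event.get("content"))
timestamp = isoparse(api_event["timestamp"]).timestamp()
print(content.parts[0].text, timestamp)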