mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
feat: Extract content encode/decode logic to a shared util and resolve issues with JSON serialization.
feat: Update key length for DB table to avoid key too long issue in mysql PiperOrigin-RevId: 753614879
This commit is contained in:
parent
b691904e57
commit
14933ba470
29
src/google/adk/sessions/_session_util.py
Normal file
29
src/google/adk/sessions/_session_util.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Utility functions for session service."""
|
||||
|
||||
import base64
|
||||
from typing import Any, Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def encode_content(content: types.Content):
|
||||
"""Encodes a content object to a JSON dictionary."""
|
||||
encoded_content = content.model_dump(exclude_none=True)
|
||||
for p in encoded_content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = base64.b64encode(
|
||||
p["inline_data"]["data"]
|
||||
).decode("utf-8")
|
||||
return encoded_content
|
||||
|
||||
|
||||
def decode_content(
|
||||
content: Optional[dict[str, Any]],
|
||||
) -> Optional[types.Content]:
|
||||
"""Decodes a content object from a JSON dictionary."""
|
||||
if not content:
|
||||
return None
|
||||
for p in content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
|
||||
return types.Content.model_validate(content)
|
@ -11,8 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import json
|
||||
@ -20,13 +18,13 @@ import logging
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import Dialect
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import mysql
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
@ -48,6 +46,7 @@ from typing_extensions import override
|
||||
from tzlocal import get_localzone
|
||||
|
||||
from ..events.event import Event
|
||||
from . import _session_util
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
from .base_session_service import ListEventsResponse
|
||||
@ -58,6 +57,7 @@ from .state import State
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MAX_KEY_LENGTH = 128
|
||||
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
||||
|
||||
|
||||
@ -72,15 +72,16 @@ class DynamicJSON(TypeDecorator):
|
||||
def load_dialect_impl(self, dialect: Dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(postgresql.JSONB)
|
||||
else:
|
||||
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
||||
if dialect.name == "mysql":
|
||||
# Use LONGTEXT for MySQL to address the data too long issue
|
||||
return dialect.type_descriptor(mysql.LONGTEXT)
|
||||
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
||||
|
||||
def process_bind_param(self, value, dialect: Dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value # JSONB handles dict directly
|
||||
else:
|
||||
return json.dumps(value) # Serialize to JSON string for TEXT
|
||||
return json.dumps(value) # Serialize to JSON string for TEXT
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect: Dialect):
|
||||
@ -104,13 +105,13 @@ class StorageSession(Base):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH),
|
||||
String(DEFAULT_MAX_KEY_LENGTH),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
@ -139,16 +140,16 @@ class StorageEvent(Base):
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
|
||||
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
@ -209,7 +210,7 @@ class StorageAppState(Base):
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
@ -224,13 +225,10 @@ class StorageUserState(Base):
|
||||
__tablename__ = "user_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
@ -417,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
author=e.author,
|
||||
branch=e.branch,
|
||||
invocation_id=e.invocation_id,
|
||||
content=_decode_content(e.content),
|
||||
content=_session_util.decode_content(e.content),
|
||||
actions=e.actions,
|
||||
timestamp=e.timestamp.timestamp(),
|
||||
long_running_tool_ids=e.long_running_tool_ids,
|
||||
@ -540,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
interrupted=event.interrupted,
|
||||
)
|
||||
if event.content:
|
||||
encoded_content = event.content.model_dump(exclude_none=True)
|
||||
# Workaround for multimodal Content throwing JSON not serializable
|
||||
# error with SQLAlchemy.
|
||||
for p in encoded_content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = (
|
||||
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
|
||||
)
|
||||
storage_event.content = encoded_content
|
||||
storage_event.content = _session_util.encode_content(event.content)
|
||||
|
||||
sessionFactory.add(storage_event)
|
||||
|
||||
@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state):
|
||||
for key in user_state.keys():
|
||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||
return merged_state
|
||||
|
||||
|
||||
def _decode_content(
|
||||
content: Optional[dict[str, Any]],
|
||||
) -> Optional[types.Content]:
|
||||
if not content:
|
||||
return None
|
||||
for p in content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
||||
return types.Content.model_validate(content)
|
||||
|
@ -14,21 +14,23 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from dateutil import parser
|
||||
from google import genai
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from ..events.event_actions import EventActions
|
||||
from . import _session_util
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
from .base_session_service import ListEventsResponse
|
||||
from .base_session_service import ListSessionsResponse
|
||||
from .session import Session
|
||||
|
||||
|
||||
isoparse = parser.isoparse
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
|
||||
}
|
||||
event_json['actions'] = actions_json
|
||||
if event.content:
|
||||
event_json['content'] = event.content.model_dump(exclude_none=True)
|
||||
event_json['content'] = _session_util.encode_content(event.content)
|
||||
if event.error_code:
|
||||
event_json['error_code'] = event.error_code
|
||||
if event.error_message:
|
||||
@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
|
||||
invocation_id=api_event['invocationId'],
|
||||
author=api_event['author'],
|
||||
actions=event_actions,
|
||||
content=api_event.get('content', None),
|
||||
content=_session_util.decode_content(api_event.get('content', None)),
|
||||
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
||||
error_code=api_event.get('errorCode', None),
|
||||
error_message=api_event.get('errorMessage', None),
|
||||
|
Loading…
Reference in New Issue
Block a user