mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
feat: Extract content encode/decode logic to a shared util and resolve issues with JSON serialization.
feat: Update key length for DB table to avoid key too long issue in mysql PiperOrigin-RevId: 753614879
This commit is contained in:
parent
b691904e57
commit
14933ba470
29
src/google/adk/sessions/_session_util.py
Normal file
29
src/google/adk/sessions/_session_util.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Utility functions for session service."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
|
||||||
|
def encode_content(content: types.Content):
|
||||||
|
"""Encodes a content object to a JSON dictionary."""
|
||||||
|
encoded_content = content.model_dump(exclude_none=True)
|
||||||
|
for p in encoded_content["parts"]:
|
||||||
|
if "inline_data" in p:
|
||||||
|
p["inline_data"]["data"] = base64.b64encode(
|
||||||
|
p["inline_data"]["data"]
|
||||||
|
).decode("utf-8")
|
||||||
|
return encoded_content
|
||||||
|
|
||||||
|
|
||||||
|
def decode_content(
|
||||||
|
content: Optional[dict[str, Any]],
|
||||||
|
) -> Optional[types.Content]:
|
||||||
|
"""Decodes a content object from a JSON dictionary."""
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
for p in content["parts"]:
|
||||||
|
if "inline_data" in p:
|
||||||
|
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
|
||||||
|
return types.Content.model_validate(content)
|
@ -11,8 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
|
||||||
import copy
|
import copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
@ -20,13 +18,13 @@ import logging
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from google.genai import types
|
|
||||||
from sqlalchemy import Boolean
|
from sqlalchemy import Boolean
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
from sqlalchemy import Dialect
|
from sqlalchemy import Dialect
|
||||||
from sqlalchemy import ForeignKeyConstraint
|
from sqlalchemy import ForeignKeyConstraint
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.dialects import mysql
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
from sqlalchemy.engine import create_engine
|
from sqlalchemy.engine import create_engine
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
@ -48,6 +46,7 @@ from typing_extensions import override
|
|||||||
from tzlocal import get_localzone
|
from tzlocal import get_localzone
|
||||||
|
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
|
from . import _session_util
|
||||||
from .base_session_service import BaseSessionService
|
from .base_session_service import BaseSessionService
|
||||||
from .base_session_service import GetSessionConfig
|
from .base_session_service import GetSessionConfig
|
||||||
from .base_session_service import ListEventsResponse
|
from .base_session_service import ListEventsResponse
|
||||||
@ -58,6 +57,7 @@ from .state import State
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MAX_KEY_LENGTH = 128
|
||||||
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
||||||
|
|
||||||
|
|
||||||
@ -72,15 +72,16 @@ class DynamicJSON(TypeDecorator):
|
|||||||
def load_dialect_impl(self, dialect: Dialect):
|
def load_dialect_impl(self, dialect: Dialect):
|
||||||
if dialect.name == "postgresql":
|
if dialect.name == "postgresql":
|
||||||
return dialect.type_descriptor(postgresql.JSONB)
|
return dialect.type_descriptor(postgresql.JSONB)
|
||||||
else:
|
if dialect.name == "mysql":
|
||||||
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
# Use LONGTEXT for MySQL to address the data too long issue
|
||||||
|
return dialect.type_descriptor(mysql.LONGTEXT)
|
||||||
|
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect: Dialect):
|
def process_bind_param(self, value, dialect: Dialect):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if dialect.name == "postgresql":
|
if dialect.name == "postgresql":
|
||||||
return value # JSONB handles dict directly
|
return value # JSONB handles dict directly
|
||||||
else:
|
return json.dumps(value) # Serialize to JSON string for TEXT
|
||||||
return json.dumps(value) # Serialize to JSON string for TEXT
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def process_result_value(self, value, dialect: Dialect):
|
def process_result_value(self, value, dialect: Dialect):
|
||||||
@ -104,13 +105,13 @@ class StorageSession(Base):
|
|||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(
|
app_name: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
user_id: Mapped[str] = mapped_column(
|
user_id: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH),
|
String(DEFAULT_MAX_KEY_LENGTH),
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
default=lambda: str(uuid.uuid4()),
|
default=lambda: str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
@ -139,16 +140,16 @@ class StorageEvent(Base):
|
|||||||
__tablename__ = "events"
|
__tablename__ = "events"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
app_name: Mapped[str] = mapped_column(
|
app_name: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
user_id: Mapped[str] = mapped_column(
|
user_id: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
session_id: Mapped[str] = mapped_column(
|
session_id: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
|
|
||||||
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||||
@ -209,7 +210,7 @@ class StorageAppState(Base):
|
|||||||
__tablename__ = "app_states"
|
__tablename__ = "app_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(
|
app_name: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
@ -224,13 +225,10 @@ class StorageUserState(Base):
|
|||||||
__tablename__ = "user_states"
|
__tablename__ = "user_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(
|
app_name: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
)
|
||||||
user_id: Mapped[str] = mapped_column(
|
user_id: Mapped[str] = mapped_column(
|
||||||
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
)
|
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
|
||||||
)
|
)
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
@ -417,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
author=e.author,
|
author=e.author,
|
||||||
branch=e.branch,
|
branch=e.branch,
|
||||||
invocation_id=e.invocation_id,
|
invocation_id=e.invocation_id,
|
||||||
content=_decode_content(e.content),
|
content=_session_util.decode_content(e.content),
|
||||||
actions=e.actions,
|
actions=e.actions,
|
||||||
timestamp=e.timestamp.timestamp(),
|
timestamp=e.timestamp.timestamp(),
|
||||||
long_running_tool_ids=e.long_running_tool_ids,
|
long_running_tool_ids=e.long_running_tool_ids,
|
||||||
@ -540,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
interrupted=event.interrupted,
|
interrupted=event.interrupted,
|
||||||
)
|
)
|
||||||
if event.content:
|
if event.content:
|
||||||
encoded_content = event.content.model_dump(exclude_none=True)
|
storage_event.content = _session_util.encode_content(event.content)
|
||||||
# Workaround for multimodal Content throwing JSON not serializable
|
|
||||||
# error with SQLAlchemy.
|
|
||||||
for p in encoded_content["parts"]:
|
|
||||||
if "inline_data" in p:
|
|
||||||
p["inline_data"]["data"] = (
|
|
||||||
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
|
|
||||||
)
|
|
||||||
storage_event.content = encoded_content
|
|
||||||
|
|
||||||
sessionFactory.add(storage_event)
|
sessionFactory.add(storage_event)
|
||||||
|
|
||||||
@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state):
|
|||||||
for key in user_state.keys():
|
for key in user_state.keys():
|
||||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||||
return merged_state
|
return merged_state
|
||||||
|
|
||||||
|
|
||||||
def _decode_content(
|
|
||||||
content: Optional[dict[str, Any]],
|
|
||||||
) -> Optional[types.Content]:
|
|
||||||
if not content:
|
|
||||||
return None
|
|
||||||
for p in content["parts"]:
|
|
||||||
if "inline_data" in p:
|
|
||||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
|
||||||
return types.Content.model_validate(content)
|
|
||||||
|
@ -14,21 +14,23 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil import parser
|
||||||
from google import genai
|
from google import genai
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
from ..events.event_actions import EventActions
|
from ..events.event_actions import EventActions
|
||||||
|
from . import _session_util
|
||||||
from .base_session_service import BaseSessionService
|
from .base_session_service import BaseSessionService
|
||||||
from .base_session_service import GetSessionConfig
|
from .base_session_service import GetSessionConfig
|
||||||
from .base_session_service import ListEventsResponse
|
from .base_session_service import ListEventsResponse
|
||||||
from .base_session_service import ListSessionsResponse
|
from .base_session_service import ListSessionsResponse
|
||||||
from .session import Session
|
from .session import Session
|
||||||
|
|
||||||
|
|
||||||
|
isoparse = parser.isoparse
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
|
|||||||
}
|
}
|
||||||
event_json['actions'] = actions_json
|
event_json['actions'] = actions_json
|
||||||
if event.content:
|
if event.content:
|
||||||
event_json['content'] = event.content.model_dump(exclude_none=True)
|
event_json['content'] = _session_util.encode_content(event.content)
|
||||||
if event.error_code:
|
if event.error_code:
|
||||||
event_json['error_code'] = event.error_code
|
event_json['error_code'] = event.error_code
|
||||||
if event.error_message:
|
if event.error_message:
|
||||||
@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
|
|||||||
invocation_id=api_event['invocationId'],
|
invocation_id=api_event['invocationId'],
|
||||||
author=api_event['author'],
|
author=api_event['author'],
|
||||||
actions=event_actions,
|
actions=event_actions,
|
||||||
content=api_event.get('content', None),
|
content=_session_util.decode_content(api_event.get('content', None)),
|
||||||
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
||||||
error_code=api_event.get('errorCode', None),
|
error_code=api_event.get('errorCode', None),
|
||||||
error_message=api_event.get('errorMessage', None),
|
error_message=api_event.get('errorMessage', None),
|
||||||
|
Loading…
Reference in New Issue
Block a user