From 2a9ddec7e323dc9c9c3c7b456b835d41bff45353 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 29 Apr 2025 16:07:21 -0700 Subject: [PATCH] Set the max size of strings in database columns. PiperOrigin-RevId: 752918808 --- .../adk/sessions/database_session_service.py | 65 ++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 93e66f7..9bfa3cc 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -58,6 +58,8 @@ from .state import State logger = logging.getLogger(__name__) +DEFAULT_MAX_VARCHAR_LENGTH = 256 + class DynamicJSON(TypeDecorator): """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON @@ -92,17 +94,25 @@ class DynamicJSON(TypeDecorator): class Base(DeclarativeBase): """Base class for database tables.""" + pass class StorageSession(Base): """Represents a session stored in the database.""" + __tablename__ = "sessions" - app_name: Mapped[str] = mapped_column(String, primary_key=True) - user_id: Mapped[str] = mapped_column(String, primary_key=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) id: Mapped[str] = mapped_column( - String, primary_key=True, default=lambda: str(uuid.uuid4()) + String(DEFAULT_MAX_VARCHAR_LENGTH), + primary_key=True, + default=lambda: str(uuid.uuid4()), ) state: Mapped[MutableDict[str, Any]] = mapped_column( @@ -125,16 +135,27 @@ class StorageSession(Base): class StorageEvent(Base): """Represents an event stored in the database.""" + __tablename__ = "events" - id: Mapped[str] = mapped_column(String, primary_key=True) - app_name: Mapped[str] = mapped_column(String, primary_key=True) - user_id: Mapped[str] = mapped_column(String, primary_key=True) - session_id: Mapped[str] = mapped_column(String, primary_key=True) + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) - invocation_id: Mapped[str] = mapped_column(String) - author: Mapped[str] = mapped_column(String) - branch: Mapped[str] = mapped_column(String, nullable=True) + invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + branch: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now()) content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType) @@ -147,8 +168,10 @@ class StorageEvent(Base): ) partial: Mapped[bool] = mapped_column(Boolean, nullable=True) turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column(String, nullable=True) - error_message: Mapped[str] = mapped_column(String, nullable=True) + error_code: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + error_message: Mapped[str] = mapped_column(String(1024), nullable=True) interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) storage_session: Mapped[StorageSession] = relationship( @@ -182,9 +205,12 @@ class StorageEvent(Base): class StorageAppState(Base): """Represents an app state stored in the database.""" + __tablename__ = "app_states" - app_name: Mapped[str] = mapped_column(String, primary_key=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} ) @@ -192,13 +218,20 @@ class StorageAppState(Base): DateTime(), default=func.now(), onupdate=func.now() ) - class StorageUserState(Base): """Represents a user state stored in the database.""" + __tablename__ = "user_states" - app_name: Mapped[str] = mapped_column(String, primary_key=True) - user_id: Mapped[str] = mapped_column(String, primary_key=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} )