diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index c15414a..fdf0c8a 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -61,10 +61,7 @@ DEFAULT_MAX_VARCHAR_LENGTH = 256 class DynamicJSON(TypeDecorator): - """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON - - serialization for other databases. - """ + """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases.""" impl = Text # Default implementation is TEXT @@ -242,10 +239,7 @@ class DatabaseSessionService(BaseSessionService): """A session service that uses a database for storage.""" def __init__(self, db_url: str): - """ - Args: - db_url: The database URL to connect to. - """ + """Initializes the database session service with a database URL.""" # 1. Create DB engine for db connection # 2. Create all tables based on schema # 3. Initialize all properties @@ -274,7 +268,7 @@ class DatabaseSessionService(BaseSessionService): self.inspector = inspect(self.db_engine) # DB session factory method - self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = ( + self.database_session_factory: sessionmaker[DatabaseSessionFactory] = ( sessionmaker(bind=self.db_engine) ) @@ -297,11 +291,11 @@ class DatabaseSessionService(BaseSessionService): # 4. Build the session object with generated id # 5. Return the session - with self.DatabaseSessionFactory() as sessionFactory: + with self.database_session_factory() as session_factory: # Fetch app and user states from storage - storage_app_state = sessionFactory.get(StorageAppState, (app_name)) - storage_user_state = sessionFactory.get( + storage_app_state = session_factory.get(StorageAppState, (app_name)) + storage_user_state = session_factory.get( StorageUserState, (app_name, user_id) ) @@ -311,12 +305,12 @@ class DatabaseSessionService(BaseSessionService): # Create state tables if not exist if not storage_app_state: storage_app_state = StorageAppState(app_name=app_name, state={}) - sessionFactory.add(storage_app_state) + session_factory.add(storage_app_state) if not storage_user_state: storage_user_state = StorageUserState( app_name=app_name, user_id=user_id, state={} ) - sessionFactory.add(storage_user_state) + session_factory.add(storage_user_state) # Extract state deltas app_state_delta, user_state_delta, session_state = _extract_state_delta( @@ -340,10 +334,10 @@ class DatabaseSessionService(BaseSessionService): id=session_id, state=session_state, ) - sessionFactory.add(storage_session) - sessionFactory.commit() + session_factory.add(storage_session) + session_factory.commit() - sessionFactory.refresh(storage_session) + session_factory.refresh(storage_session) # Merge states for response merged_state = _merge_state(app_state, user_state, session_state) @@ -368,31 +362,37 @@ class DatabaseSessionService(BaseSessionService): # 1. Get the storage session entry from session table # 2. Get all the events based on session id and filtering config # 3. Convert and return the session - with self.DatabaseSessionFactory() as sessionFactory: - storage_session = sessionFactory.get( + with self.database_session_factory() as session_factory: + storage_session = session_factory.get( StorageSession, (app_name, user_id, session_id) ) if storage_session is None: return None - + if config and config.after_timestamp: - after_dt = datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc) + after_dt = datetime.fromtimestamp( + config.after_timestamp, tz=timezone.utc + ) timestamp_filter = StorageEvent.timestamp > after_dt else: timestamp_filter = True storage_events = ( - sessionFactory.query(StorageEvent) + session_factory.query(StorageEvent) .filter(StorageEvent.session_id == storage_session.id) - .filter(timestamp_filter) + .filter(timestamp_filter) .order_by(StorageEvent.timestamp.asc()) - .limit(config.num_recent_events if config and config.num_recent_events else None) + .limit( + config.num_recent_events + if config and config.num_recent_events + else None + ) .all() ) # Fetch states from storage - storage_app_state = sessionFactory.get(StorageAppState, (app_name)) - storage_user_state = sessionFactory.get( + storage_app_state = session_factory.get(StorageAppState, (app_name)) + storage_user_state = session_factory.get( StorageUserState, (app_name, user_id) ) @@ -436,9 +436,9 @@ class DatabaseSessionService(BaseSessionService): async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: - with self.DatabaseSessionFactory() as sessionFactory: + with self.database_session_factory() as session_factory: results = ( - sessionFactory.query(StorageSession) + session_factory.query(StorageSession) .filter(StorageSession.app_name == app_name) .filter(StorageSession.user_id == user_id) .all() @@ -459,14 +459,14 @@ class DatabaseSessionService(BaseSessionService): async def delete_session( self, app_name: str, user_id: str, session_id: str ) -> None: - with self.DatabaseSessionFactory() as sessionFactory: + with self.database_session_factory() as session_factory: stmt = delete(StorageSession).where( StorageSession.app_name == app_name, StorageSession.user_id == user_id, StorageSession.id == session_id, ) - sessionFactory.execute(stmt) - sessionFactory.commit() + session_factory.execute(stmt) + session_factory.commit() @override async def append_event(self, session: Session, event: Event) -> Event: @@ -478,8 +478,8 @@ class DatabaseSessionService(BaseSessionService): # 1. Check if timestamp is stale # 2. Update session attributes based on event config # 3. Store event to table - with self.DatabaseSessionFactory() as sessionFactory: - storage_session = sessionFactory.get( + with self.database_session_factory() as session_factory: + storage_session = session_factory.get( StorageSession, (session.app_name, session.user_id, session.id) ) @@ -493,10 +493,10 @@ class DatabaseSessionService(BaseSessionService): ) # Fetch states from storage - storage_app_state = sessionFactory.get( + storage_app_state = session_factory.get( StorageAppState, (session.app_name) ) - storage_user_state = sessionFactory.get( + storage_user_state = session_factory.get( StorageUserState, (session.app_name, session.user_id) ) @@ -545,10 +545,10 @@ class DatabaseSessionService(BaseSessionService): if event.content: storage_event.content = _session_util.encode_content(event.content) - sessionFactory.add(storage_event) + session_factory.add(storage_event) - sessionFactory.commit() - sessionFactory.refresh(storage_session) + session_factory.commit() + session_factory.refresh(storage_session) # Update timestamp with commit time session.last_update_time = storage_session.update_time.timestamp()