From bc43a1196aa8605e6db8fe879d2385bce2adb5e2 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 14 May 2025 14:22:04 -0700 Subject: [PATCH] feat: add _sync implementation in the inmemory session sevice. PiperOrigin-RevId: 758832846 --- .../adk/sessions/in_memory_session_service.py | 102 +++++++++++++++++- 1 file changed, 100 insertions(+), 2 deletions(-) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 69767f2..d37d901 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import logging import time from typing import Any from typing import Optional @@ -28,12 +29,15 @@ from .base_session_service import ListSessionsResponse from .session import Session from .state import State +logger = logging.getLogger(__name__) + class InMemorySessionService(BaseSessionService): """An in-memory implementation of the session service.""" def __init__(self): - # A map from app name to a map from user ID to a map from session ID to session. + # A map from app name to a map from user ID to a map from session ID to + # session. self.sessions: dict[str, dict[str, dict[str, Session]]] = {} # A map from app name to a map from user ID to a map from key to the value. self.user_state: dict[str, dict[str, dict[str, Any]]] = {} @@ -48,6 +52,37 @@ class InMemorySessionService(BaseSessionService): user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + ) -> Session: + return self._create_session_impl( + app_name=app_name, + user_id=user_id, + state=state, + session_id=session_id, + ) + + def create_session_sync( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + logger.warning('Deprecated. Please migrate to the async method.') + return self._create_session_impl( + app_name=app_name, + user_id=user_id, + state=state, + session_id=session_id, + ) + + def _create_session_impl( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, ) -> Session: session_id = ( session_id.strip() @@ -79,6 +114,37 @@ class InMemorySessionService(BaseSessionService): user_id: str, session_id: str, config: Optional[GetSessionConfig] = None, + ) -> Session: + return self._get_session_impl( + app_name=app_name, + user_id=user_id, + session_id=session_id, + config=config, + ) + + def get_session_sync( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Session: + logger.warning('Deprecated. Please migrate to the async method.') + return self._get_session_impl( + app_name=app_name, + user_id=user_id, + session_id=session_id, + config=config, + ) + + def _get_session_impl( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, ) -> Session: if app_name not in self.sessions: return None @@ -130,6 +196,17 @@ class InMemorySessionService(BaseSessionService): @override def list_sessions( self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + return self._list_sessions_impl(app_name=app_name, user_id=user_id) + + def list_sessions_sync( + self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + logger.warning('Deprecated. Please migrate to the async method.') + return self._list_sessions_impl(app_name=app_name, user_id=user_id) + + def _list_sessions_impl( + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: empty_response = ListSessionsResponse() if app_name not in self.sessions: @@ -145,9 +222,23 @@ class InMemorySessionService(BaseSessionService): sessions_without_events.append(copied_session) return ListSessionsResponse(sessions=sessions_without_events) - @override def delete_session( self, *, app_name: str, user_id: str, session_id: str + ) -> None: + self._delete_session_impl( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + def delete_session_sync( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + logger.warning('Deprecated. Please migrate to the async method.') + self._delete_session_impl( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + def _delete_session_impl( + self, *, app_name: str, user_id: str, session_id: str ) -> None: if ( self.get_session( @@ -161,6 +252,13 @@ class InMemorySessionService(BaseSessionService): @override def append_event(self, session: Session, event: Event) -> Event: + return self._append_event_impl(session=session, event=event) + + def append_event_sync(self, session: Session, event: Event) -> Event: + logger.warning('Deprecated. Please migrate to the async method.') + return self._append_event_impl(session=session, event=event) + + def _append_event_impl(self, session: Session, event: Event) -> Event: # Update the in-memory session. super().append_event(session=session, event=event) session.last_update_time = event.timestamp