feat: add _sync implementation in the inmemory session sevice.

PiperOrigin-RevId: 758832846
This commit is contained in:
Shangjie Chen 2025-05-14 14:22:04 -07:00 committed by Copybara-Service
parent dc90c91ed1
commit bc43a1196a

View File

@ -13,6 +13,7 @@
# limitations under the License.
import copy
import logging
import time
from typing import Any
from typing import Optional
@ -28,12 +29,15 @@ from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
logger = logging.getLogger(__name__)
class InMemorySessionService(BaseSessionService):
"""An in-memory implementation of the session service."""
def __init__(self):
# A map from app name to a map from user ID to a map from session ID to session.
# A map from app name to a map from user ID to a map from session ID to
# session.
self.sessions: dict[str, dict[str, dict[str, Session]]] = {}
# A map from app name to a map from user ID to a map from key to the value.
self.user_state: dict[str, dict[str, dict[str, Any]]] = {}
@ -48,6 +52,37 @@ class InMemorySessionService(BaseSessionService):
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
)
def create_session_sync(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
logger.warning('Deprecated. Please migrate to the async method.')
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
)
def _create_session_impl(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
session_id = (
session_id.strip()
@ -79,6 +114,37 @@ class InMemorySessionService(BaseSessionService):
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Session:
return self._get_session_impl(
app_name=app_name,
user_id=user_id,
session_id=session_id,
config=config,
)
def get_session_sync(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Session:
logger.warning('Deprecated. Please migrate to the async method.')
return self._get_session_impl(
app_name=app_name,
user_id=user_id,
session_id=session_id,
config=config,
)
def _get_session_impl(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Session:
if app_name not in self.sessions:
return None
@ -130,6 +196,17 @@ class InMemorySessionService(BaseSessionService):
@override
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
def list_sessions_sync(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
logger.warning('Deprecated. Please migrate to the async method.')
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
def _list_sessions_impl(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
empty_response = ListSessionsResponse()
if app_name not in self.sessions:
@ -145,9 +222,23 @@ class InMemorySessionService(BaseSessionService):
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)
@override
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
self._delete_session_impl(
app_name=app_name, user_id=user_id, session_id=session_id
)
def delete_session_sync(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
logger.warning('Deprecated. Please migrate to the async method.')
self._delete_session_impl(
app_name=app_name, user_id=user_id, session_id=session_id
)
def _delete_session_impl(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
if (
self.get_session(
@ -161,6 +252,13 @@ class InMemorySessionService(BaseSessionService):
@override
def append_event(self, session: Session, event: Event) -> Event:
return self._append_event_impl(session=session, event=event)
def append_event_sync(self, session: Session, event: Event) -> Event:
logger.warning('Deprecated. Please migrate to the async method.')
return self._append_event_impl(session=session, event=event)
def _append_event_impl(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
super().append_event(session=session, event=event)
session.last_update_time = event.timestamp