diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 282e8e5..b2a84ef 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import copy import logging @@ -223,6 +224,7 @@ class InMemorySessionService(BaseSessionService): sessions_without_events.append(copied_session) return ListSessionsResponse(sessions=sessions_without_events) + @override async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: @@ -247,7 +249,7 @@ class InMemorySessionService(BaseSessionService): ) is None ): - return None + return self.sessions[app_name][user_id].pop(session_id) @@ -261,11 +263,20 @@ class InMemorySessionService(BaseSessionService): app_name = session.app_name user_id = session.user_id session_id = session.id + + def _warning(message: str) -> None: + logger.warning( + f'Failed to append event to session {session_id}: {message}' + ) + if app_name not in self.sessions: + _warning(f'app_name {app_name} not in sessions') return event if user_id not in self.sessions[app_name]: + _warning(f'user_id {user_id} not in sessions[app_name]') return event if session_id not in self.sessions[app_name][user_id]: + _warning(f'session_id {session_id} not in sessions[app_name][user_id]') return event if event.actions and event.actions.state_delta: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 2cff001..5d6bed2 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -16,8 +16,8 @@ from __future__ import annotations import asyncio import logging import re -import time from typing import Any +from typing import Dict from typing import Optional import urllib.parse @@ -50,9 +50,6 @@ class VertexAiSessionService(BaseSessionService): self.project = project self.location = location - client = genai.Client(vertexai=True, project=project, location=location) - self.api_client = client._api_client - @override async def create_session( self, @@ -86,6 +83,7 @@ class VertexAiSessionService(BaseSessionService): operation_id = api_response['name'].split('/')[-1] max_retry_attempt = 5 + lro_response = None while max_retry_attempt >= 0: lro_response = await api_client.async_request( http_method='GET', @@ -99,6 +97,11 @@ class VertexAiSessionService(BaseSessionService): await asyncio.sleep(1) max_retry_attempt -= 1 + if lro_response is None or not lro_response.get('done', None): + raise TimeoutError( + f'Timeout waiting for operation {operation_id} to complete.' + ) + # Get session resource get_session_api_response = await api_client.async_request( http_method='GET', @@ -235,11 +238,15 @@ class VertexAiSessionService(BaseSessionService): ) -> None: reasoning_engine_id = _parse_reasoning_engine_id(app_name) api_client = _get_api_client(self.project, self.location) - await api_client.async_request( - http_method='DELETE', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', - request_dict={}, - ) + try: + await api_client.async_request( + http_method='DELETE', + path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', + request_dict={}, + ) + except Exception as e: + logger.error(f'Error deleting session {session_id}: {e}') + raise e @override async def append_event(self, session: Session, event: Event) -> Event: @@ -266,7 +273,7 @@ def _get_api_client(project: str, location: str): return client._api_client -def _convert_event_to_json(event: Event): +def _convert_event_to_json(event: Event) -> Dict[str, Any]: metadata_json = { 'partial': event.partial, 'turn_complete': event.turn_complete, @@ -318,7 +325,7 @@ def _convert_event_to_json(event: Event): return event_json -def _from_api_event(api_event: dict) -> Event: +def _from_api_event(api_event: Dict[str, Any]) -> Event: event_actions = EventActions() if api_event.get('actions', None): event_actions = EventActions(