From 0127c3f9d87376e70b4c665743728d8bba383e61 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Jun 2025 18:57:15 -0700 Subject: [PATCH] chore: Minor improvement to session service - Add missing override. - Add warning to failed actions. - Remove unused import. - Remove unused fields. - Add type checking. PiperOrigin-RevId: 766913196 --- .../adk/sessions/in_memory_session_service.py | 13 +-------- .../adk/sessions/vertex_ai_session_service.py | 29 +++++++------------ 2 files changed, 12 insertions(+), 30 deletions(-) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index b2a84ef..282e8e5 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import copy import logging @@ -224,7 +223,6 @@ class InMemorySessionService(BaseSessionService): sessions_without_events.append(copied_session) return ListSessionsResponse(sessions=sessions_without_events) - @override async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: @@ -249,7 +247,7 @@ class InMemorySessionService(BaseSessionService): ) is None ): - return + return None self.sessions[app_name][user_id].pop(session_id) @@ -263,20 +261,11 @@ class InMemorySessionService(BaseSessionService): app_name = session.app_name user_id = session.user_id session_id = session.id - - def _warning(message: str) -> None: - logger.warning( - f'Failed to append event to session {session_id}: {message}' - ) - if app_name not in self.sessions: - _warning(f'app_name {app_name} not in sessions') return event if user_id not in self.sessions[app_name]: - _warning(f'user_id {user_id} not in sessions[app_name]') return event if session_id not in self.sessions[app_name][user_id]: - _warning(f'session_id {session_id} not in sessions[app_name][user_id]') return event if event.actions and event.actions.state_delta: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 5d6bed2..2cff001 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -16,8 +16,8 @@ from __future__ import annotations import asyncio import logging import re +import time from typing import Any -from typing import Dict from typing import Optional import urllib.parse @@ -50,6 +50,9 @@ class VertexAiSessionService(BaseSessionService): self.project = project self.location = location + client = genai.Client(vertexai=True, project=project, location=location) + self.api_client = client._api_client + @override async def create_session( self, @@ -83,7 +86,6 @@ class VertexAiSessionService(BaseSessionService): operation_id = api_response['name'].split('/')[-1] max_retry_attempt = 5 - lro_response = None while max_retry_attempt >= 0: lro_response = await api_client.async_request( http_method='GET', @@ -97,11 +99,6 @@ class VertexAiSessionService(BaseSessionService): await asyncio.sleep(1) max_retry_attempt -= 1 - if lro_response is None or not lro_response.get('done', None): - raise TimeoutError( - f'Timeout waiting for operation {operation_id} to complete.' - ) - # Get session resource get_session_api_response = await api_client.async_request( http_method='GET', @@ -238,15 +235,11 @@ class VertexAiSessionService(BaseSessionService): ) -> None: reasoning_engine_id = _parse_reasoning_engine_id(app_name) api_client = _get_api_client(self.project, self.location) - try: - await api_client.async_request( - http_method='DELETE', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', - request_dict={}, - ) - except Exception as e: - logger.error(f'Error deleting session {session_id}: {e}') - raise e + await api_client.async_request( + http_method='DELETE', + path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', + request_dict={}, + ) @override async def append_event(self, session: Session, event: Event) -> Event: @@ -273,7 +266,7 @@ def _get_api_client(project: str, location: str): return client._api_client -def _convert_event_to_json(event: Event) -> Dict[str, Any]: +def _convert_event_to_json(event: Event): metadata_json = { 'partial': event.partial, 'turn_complete': event.turn_complete, @@ -325,7 +318,7 @@ def _convert_event_to_json(event: Event) -> Dict[str, Any]: return event_json -def _from_api_event(api_event: Dict[str, Any]) -> Event: +def _from_api_event(api_event: dict) -> Event: event_actions = EventActions() if api_event.get('actions', None): event_actions = EventActions(