# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import asyncio
import logging
import re
from typing import Any
from typing import Dict
from typing import Optional
import urllib.parse

from dateutil import parser
from google.genai import types
from typing_extensions import override

from google import genai

from . import _session_util
from ..events.event import Event
from ..events.event_actions import EventActions
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .session import Session

isoparse = parser.isoparse
logger = logging.getLogger('google_adk.' + __name__)


class VertexAiSessionService(BaseSessionService):
  """Connects to the managed Vertex AI Session Service."""

  def __init__(
      self,
      project: Optional[str] = None,
      location: Optional[str] = None,
  ):
    """Stores the project and location used to build API clients.

    Args:
      project: Project passed to `genai.Client` for every request.
      location: Location passed to `genai.Client` for every request.
    """
    self.project = project
    self.location = location

  @override
  async def create_session(
      self,
      *,
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
      session_id: Optional[str] = None,
  ) -> Session:
    """Creates a session and waits for the creation operation to finish.

    Args:
      app_name: Reasoning engine id, or the full ReasoningEngine resource name.
      user_id: User the session belongs to.
      state: Optional initial session state.
      session_id: Must be falsy; caller-chosen ids are rejected.

    Returns:
      The created session, populated from a follow-up GET of the resource.

    Raises:
      ValueError: If `session_id` is provided or `app_name` is invalid.
      TimeoutError: If the operation is not done after all polling attempts.
    """
    if session_id:
      raise ValueError(
          'User-provided Session id is not supported for'
          ' VertexAISessionService.'
      )

    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    session_json_dict = {'user_id': user_id}
    if state:
      session_json_dict['session_state'] = state

    api_client = _get_api_client(self.project, self.location)
    api_response = await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions',
        request_dict=session_json_dict,
    )
    logger.info('Create Session response %s', api_response)

    # The response name is split on '/': the third-from-last segment is used
    # as the session id and the last segment as the operation id.
    name_segments = api_response['name'].split('/')
    session_id = name_segments[-3]
    operation_id = name_segments[-1]

    # Poll the operation once per second; the counter runs 5..0, i.e. up to
    # six polls before giving up.
    max_retry_attempt = 5
    lro_response = None
    while max_retry_attempt >= 0:
      lro_response = await api_client.async_request(
          http_method='GET',
          path=f'operations/{operation_id}',
          request_dict={},
      )

      if lro_response.get('done', None):
        break

      await asyncio.sleep(1)
      max_retry_attempt -= 1

    if lro_response is None or not lro_response.get('done', None):
      raise TimeoutError(
          f'Timeout waiting for operation {operation_id} to complete.'
      )

    # Get session resource
    get_session_api_response = await api_client.async_request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

    update_timestamp = isoparse(
        get_session_api_response['updateTime']
    ).timestamp()
    session = Session(
        app_name=str(app_name),
        user_id=str(user_id),
        id=str(session_id),
        state=get_session_api_response.get('sessionState', {}),
        last_update_time=update_timestamp,
    )
    return session

  @override
  async def get_session(
      self,
      *,
      app_name: str,
      user_id: str,
      session_id: str,
      config: Optional[GetSessionConfig] = None,
  ) -> Optional[Session]:
    """Fetches a session together with its (optionally filtered) events.

    Args:
      app_name: Reasoning engine id, or the full ReasoningEngine resource name.
      user_id: User id copied onto the returned session.
      session_id: Id of the session to fetch.
      config: Optional filter. `num_recent_events` keeps only the last N
        events; otherwise `after_timestamp` keeps only events whose timestamp
        is at or after that value.

    Returns:
      The session with events sorted by ascending timestamp.

    Raises:
      ValueError: If `app_name` is not a valid reasoning engine reference.
    """
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    # Get session resource
    api_client = _get_api_client(self.project, self.location)
    get_session_api_response = await api_client.async_request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

    session_id = get_session_api_response['name'].split('/')[-1]
    update_timestamp = isoparse(
        get_session_api_response['updateTime']
    ).timestamp()
    session = Session(
        app_name=str(app_name),
        user_id=str(user_id),
        id=str(session_id),
        state=get_session_api_response.get('sessionState', {}),
        last_update_time=update_timestamp,
    )

    list_events_api_response = await api_client.async_request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
        request_dict={},
    )

    # Handles empty response case: a response carrying 'httpHeaders' is
    # treated as having no events.
    if list_events_api_response.get('httpHeaders', None):
      return session

    # A page may omit 'sessionEvents'; treat that as an empty page instead of
    # failing with KeyError.
    session.events += [
        _from_api_event(event)
        for event in list_events_api_response.get('sessionEvents', [])
    ]

    while list_events_api_response.get('nextPageToken', None):
      # Escape the token so characters such as '+', '/' or '=' survive as
      # query-string data (same approach as the filter in list_sessions).
      page_token = urllib.parse.quote(
          list_events_api_response['nextPageToken'], safe=''
      )
      list_events_api_response = await api_client.async_request(
          http_method='GET',
          path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
          request_dict={},
      )
      session.events += [
          _from_api_event(event)
          for event in list_events_api_response.get('sessionEvents', [])
      ]

    # Drop events newer than the session's own update time, then order them.
    session.events = [
        event for event in session.events if event.timestamp <= update_timestamp
    ]
    session.events.sort(key=lambda event: event.timestamp)

    # Filter events based on config
    if config:
      if config.num_recent_events:
        session.events = session.events[-config.num_recent_events :]
      elif config.after_timestamp:
        # Keep only events at or after the cutoff. The previous backward scan
        # sliced from the last event *before* the cutoff and so leaked one
        # stale event into the result.
        session.events = [
            event
            for event in session.events
            if event.timestamp >= config.after_timestamp
        ]

    return session

  @override
  async def list_sessions(
      self, *, app_name: str, user_id: str
  ) -> ListSessionsResponse:
    """Lists sessions of a reasoning engine, filtered by user when given.

    Args:
      app_name: Reasoning engine id, or the full ReasoningEngine resource name.
      user_id: When truthy, only sessions matching this user id are requested.

    Returns:
      A response whose sessions carry id and last update time; their state is
      always an empty dict.

    Raises:
      ValueError: If `app_name` is not a valid reasoning engine reference.
    """
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    path = f'reasoningEngines/{reasoning_engine_id}/sessions'
    if user_id:
      # The user id is wrapped in double quotes and fully percent-encoded.
      parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
      path = path + f'?filter=user_id={parsed_user_id}'

    api_client = _get_api_client(self.project, self.location)
    api_response = await api_client.async_request(
        http_method='GET',
        path=path,
        request_dict={},
    )

    # Handles empty response case
    if api_response.get('httpHeaders', None):
      return ListSessionsResponse()

    # A missing 'sessions' key is treated as "no sessions" rather than an
    # error.
    sessions = [
        Session(
            app_name=app_name,
            user_id=user_id,
            id=api_session['name'].split('/')[-1],
            state={},
            last_update_time=isoparse(api_session['updateTime']).timestamp(),
        )
        for api_session in api_response.get('sessions', [])
    ]
    return ListSessionsResponse(sessions=sessions)

  @override
  async def delete_session(
      self, *, app_name: str, user_id: str, session_id: str
  ) -> None:
    """Deletes a session; failures are logged and re-raised.

    Args:
      app_name: Reasoning engine id, or the full ReasoningEngine resource name.
      user_id: Not used by this implementation.
      session_id: Id of the session to delete.

    Raises:
      ValueError: If `app_name` is not a valid reasoning engine reference.
      Exception: Whatever the API client raised for the DELETE request.
    """
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)
    api_client = _get_api_client(self.project, self.location)
    try:
      await api_client.async_request(
          http_method='DELETE',
          path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
          request_dict={},
      )
    except Exception as e:
      logger.error('Error deleting session %s: %s', session_id, e)
      # Bare raise keeps the original traceback intact.
      raise

  @override
  async def append_event(self, session: Session, event: Event) -> Event:
    """Appends an event to the in-memory session and to the remote session.

    Args:
      session: Session to append to; its `app_name` and `id` address the
        remote resource.
      event: Event to append.

    Returns:
      The same `event` object that was passed in.
    """
    # Update the in-memory session.
    await super().append_event(session=session, event=event)

    reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
    api_client = _get_api_client(self.project, self.location)
    await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
        request_dict=_convert_event_to_json(event),
    )
    return event


def _get_api_client(project: str, location: str):
  """Instantiates an API client for the given project and location.

  A fresh client is built on every call. NOTE(review): the original comment
  says this must happen inside each request for event-loop management reasons;
  confirm before caching the client.
  """
  client = genai.Client(vertexai=True, project=project, location=location)
  return client._api_client


def _convert_event_to_json(event: Event) -> Dict[str, Any]:
  """Builds the request payload for the `appendEvent` call from an Event.

  Args:
    event: Event to serialize.

  Returns:
    A dict with snake_case keys. `actions` and `content` are included only
    when the corresponding event attribute is truthy.
  """
  metadata_json = {
      'partial': event.partial,
      'turn_complete': event.turn_complete,
      'interrupted': event.interrupted,
      'branch': event.branch,
      'long_running_tool_ids': (
          list(event.long_running_tool_ids)
          if event.long_running_tool_ids
          else None
      ),
  }
  if event.grounding_metadata:
    metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
        exclude_none=True, mode='json'
    )

  event_json = {
      'author': event.author,
      'invocation_id': event.invocation_id,
      # Split the float timestamp into whole seconds and the fractional part
      # expressed in nanoseconds.
      'timestamp': {
          'seconds': int(event.timestamp),
          'nanos': int(
              (event.timestamp - int(event.timestamp)) * 1_000_000_000
          ),
      },
      'error_code': event.error_code,
      'error_message': event.error_message,
      'event_metadata': metadata_json,
  }

  if event.actions:
    actions_json = {
        'skip_summarization': event.actions.skip_summarization,
        'state_delta': event.actions.state_delta,
        'artifact_delta': event.actions.artifact_delta,
        'transfer_agent': event.actions.transfer_to_agent,
        'escalate': event.actions.escalate,
        'requested_auth_configs': event.actions.requested_auth_configs,
    }
    event_json['actions'] = actions_json
  if event.content:
    event_json['content'] = event.content.model_dump(
        exclude_none=True, mode='json'
    )
  return event_json


def _from_api_event(api_event: Dict[str, Any]) -> Event:
  """Converts one event dict from the list-events response into an Event.

  Args:
    api_event: Event dict with camelCase keys. `name`, `invocationId`,
      `author` and `timestamp` are required; the remaining keys are optional.

  Returns:
    The reconstructed Event. Metadata fields are set only when
    `eventMetadata` is present and truthy.
  """
  event_actions = EventActions()
  if api_event.get('actions', None):
    api_actions = api_event['actions']
    event_actions = EventActions(
        skip_summarization=api_actions.get('skipSummarization', None),
        state_delta=api_actions.get('stateDelta', {}),
        artifact_delta=api_actions.get('artifactDelta', {}),
        transfer_to_agent=api_actions.get('transferAgent', None),
        escalate=api_actions.get('escalate', None),
        requested_auth_configs=api_actions.get('requestedAuthConfigs', {}),
    )

  event = Event(
      # The event id is the last path segment of the resource name.
      id=api_event['name'].split('/')[-1],
      invocation_id=api_event['invocationId'],
      author=api_event['author'],
      actions=event_actions,
      content=_session_util.decode_content(api_event.get('content', None)),
      timestamp=isoparse(api_event['timestamp']).timestamp(),
      error_code=api_event.get('errorCode', None),
      error_message=api_event.get('errorMessage', None),
  )

  if api_event.get('eventMetadata', None):
    api_metadata = api_event['eventMetadata']
    long_running_tool_ids_list = api_metadata.get('longRunningToolIds', None)
    event.partial = api_metadata.get('partial', None)
    event.turn_complete = api_metadata.get('turnComplete', None)
    event.interrupted = api_metadata.get('interrupted', None)
    event.branch = api_metadata.get('branch', None)
    event.grounding_metadata = _session_util.decode_grounding_metadata(
        api_metadata.get('groundingMetadata', None)
    )
    event.long_running_tool_ids = (
        set(long_running_tool_ids_list) if long_running_tool_ids_list else None
    )

  return event


def _parse_reasoning_engine_id(app_name: str) -> str:
  """Extracts the reasoning engine id from an app name.

  Args:
    app_name: Either an all-digit reasoning engine id, or a full resource name
      of the form
      `projects/<project>/locations/<location>/reasoningEngines/<digits>`.

  Returns:
    The reasoning engine id as a string of digits.

  Raises:
    ValueError: If `app_name` matches neither accepted form.
  """
  if app_name.isdigit():
    return app_name

  pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
  match = re.fullmatch(pattern, app_name)

  if not match:
    raise ValueError(
        f'App name {app_name} is not valid. It should either be the full'
        ' ReasoningEngine resource name, or the reasoning engine id.'
    )

  # The id is the last capture group of the pattern.
  return match.groups()[-1]