adk-python/src/google/adk/sessions/vertex_ai_session_service.py
Shangjie Chen e99f87de73 chore: Minor improvement to session service
- Add missing override.
- Add warning to failed actions.
- Remove unused import.
- Remove unused fields.
- Add type checking.

PiperOrigin-RevId: 766882634
2025-06-03 17:13:33 -07:00

385 lines
12 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import re
from typing import Any
from typing import Dict
from typing import Optional
import urllib.parse
from dateutil import parser
from google.genai import types
from typing_extensions import override
from google import genai
from . import _session_util
from ..events.event import Event
from ..events.event_actions import EventActions
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .session import Session
# Shorthand for dateutil's ISO-8601 parser; used below to turn the API's
# 'updateTime' / 'timestamp' strings into datetimes.
isoparse = parser.isoparse
# Module logger, namespaced under the 'google_adk.' prefix.
logger = logging.getLogger('google_adk.' + __name__)
class VertexAiSessionService(BaseSessionService):
  """Connects to the managed Vertex AI Session Service.

  Every method resolves ``app_name`` to a reasoning engine id and then issues
  REST calls through an API client that is created per request.
  """

  def __init__(
      self,
      project: Optional[str] = None,
      location: Optional[str] = None,
  ):
    """Stores the project and location used to build the API client.

    Args:
      project: Project forwarded to the API client on every request.
      location: Location forwarded to the API client on every request.
    """
    self.project = project
    self.location = location

  @override
  async def create_session(
      self,
      *,
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
      session_id: Optional[str] = None,
  ) -> Session:
    """Creates a session and waits for the create operation to finish.

    Args:
      app_name: Reasoning engine id or full ReasoningEngine resource name.
      user_id: Owner of the new session.
      state: Optional initial session state; omitted from the request when
        empty.
      session_id: Must be falsy; caller-chosen ids are rejected.

    Returns:
      The created session, populated from a follow-up GET of the resource.

    Raises:
      ValueError: If ``session_id`` is provided or ``app_name`` is invalid.
      TimeoutError: If the create operation is not reported done after all
        polling attempts.
    """
    if session_id:
      raise ValueError(
          'User-provided Session id is not supported for'
          ' VertexAISessionService.'
      )

    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    session_json_dict = {'user_id': user_id}
    if state:
      session_json_dict['session_state'] = state

    api_client = _get_api_client(self.project, self.location)
    api_response = await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions',
        request_dict=session_json_dict,
    )
    logger.info('Create Session response %s', api_response)

    # The returned operation name is split on '/': the session id is the
    # third-from-last segment and the operation id is the last one
    # (presumably '.../sessions/{id}/operations/{op}' -- confirm with API).
    name_parts = api_response['name'].split('/')
    session_id = name_parts[-3]
    operation_id = name_parts[-1]

    # Poll the long-running operation: one initial check plus five retries,
    # sleeping one second after every check that is not done.
    max_retry_attempt = 5
    operation_done = False
    for _ in range(max_retry_attempt + 1):
      lro_response = await api_client.async_request(
          http_method='GET',
          path=f'operations/{operation_id}',
          request_dict={},
      )
      if lro_response.get('done', None):
        operation_done = True
        break
      await asyncio.sleep(1)

    if not operation_done:
      raise TimeoutError(
          f'Timeout waiting for operation {operation_id} to complete.'
      )

    # Get session resource
    get_session_api_response = await api_client.async_request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

    update_timestamp = isoparse(
        get_session_api_response['updateTime']
    ).timestamp()
    return Session(
        app_name=str(app_name),
        user_id=str(user_id),
        id=str(session_id),
        state=get_session_api_response.get('sessionState', {}),
        last_update_time=update_timestamp,
    )

  @override
  async def get_session(
      self,
      *,
      app_name: str,
      user_id: str,
      session_id: str,
      config: Optional[GetSessionConfig] = None,
  ) -> Optional[Session]:
    """Fetches a session together with its events.

    Args:
      app_name: Reasoning engine id or full ReasoningEngine resource name.
      user_id: Copied onto the returned session.
      session_id: Id of the session to fetch.
      config: Optional filter. ``num_recent_events`` keeps only the newest N
        events; otherwise ``after_timestamp`` keeps only events whose
        timestamp is at or after it.

    Returns:
      The session with its events sorted by timestamp. When the event
      listing is an empty response, the session is returned without events.

    Raises:
      ValueError: If ``app_name`` is invalid.
    """
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    # Get session resource
    api_client = _get_api_client(self.project, self.location)
    get_session_api_response = await api_client.async_request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

    session_id = get_session_api_response['name'].split('/')[-1]
    update_timestamp = isoparse(
        get_session_api_response['updateTime']
    ).timestamp()
    session = Session(
        app_name=str(app_name),
        user_id=str(user_id),
        id=str(session_id),
        state=get_session_api_response.get('sessionState', {}),
        last_update_time=update_timestamp,
    )

    events_path = (
        f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events'
    )
    list_events_api_response = await api_client.async_request(
        http_method='GET',
        path=events_path,
        request_dict={},
    )

    # Handles empty response case: a response carrying 'httpHeaders' is
    # treated as having no events.
    if list_events_api_response.get('httpHeaders', None):
      return session

    session.events += [
        _from_api_event(event)
        for event in list_events_api_response.get('sessionEvents', [])
    ]

    # Follow pagination until the API stops returning a next-page token.
    page_token = list_events_api_response.get('nextPageToken', None)
    while page_token:
      # Tokens are opaque and may contain characters that are not safe in a
      # query string, so encode them like the user_id filter in
      # list_sessions.
      encoded_token = urllib.parse.quote(page_token, safe='')
      list_events_api_response = await api_client.async_request(
          http_method='GET',
          path=f'{events_path}?pageToken={encoded_token}',
          request_dict={},
      )
      session.events += [
          _from_api_event(event)
          for event in list_events_api_response.get('sessionEvents', [])
      ]
      page_token = list_events_api_response.get('nextPageToken', None)

    # Drop events stamped later than the session's own update time.
    session.events = [
        event for event in session.events if event.timestamp <= update_timestamp
    ]
    session.events.sort(key=lambda event: event.timestamp)

    # Filter events based on config
    if config:
      if config.num_recent_events:
        session.events = session.events[-config.num_recent_events :]
      elif config.after_timestamp:
        # Keep only events at or after the cutoff. The previous slice kept
        # one extra event that was older than after_timestamp.
        session.events = [
            event
            for event in session.events
            if event.timestamp >= config.after_timestamp
        ]

    return session

  @override
  async def list_sessions(
      self, *, app_name: str, user_id: str
  ) -> ListSessionsResponse:
    """Lists sessions of a reasoning engine, optionally filtered by user.

    Args:
      app_name: Reasoning engine id or full ReasoningEngine resource name.
      user_id: When truthy, only sessions of this user are requested.

    Returns:
      A response whose sessions carry id and last update time; their state
      is always left empty.

    Raises:
      ValueError: If ``app_name`` is invalid.
    """
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    path = f'reasoningEngines/{reasoning_engine_id}/sessions'
    if user_id:
      # The filter value is a double-quoted string, fully percent-encoded.
      parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
      path = path + f'?filter=user_id={parsed_user_id}'

    api_client = _get_api_client(self.project, self.location)
    api_response = await api_client.async_request(
        http_method='GET',
        path=path,
        request_dict={},
    )

    # Handles empty response case
    if api_response.get('httpHeaders', None):
      return ListSessionsResponse()

    sessions = [
        Session(
            app_name=app_name,
            user_id=user_id,
            id=api_session['name'].split('/')[-1],
            state={},
            last_update_time=isoparse(api_session['updateTime']).timestamp(),
        )
        for api_session in api_response.get('sessions', [])
    ]
    return ListSessionsResponse(sessions=sessions)

  @override
  async def delete_session(
      self, *, app_name: str, user_id: str, session_id: str
  ) -> None:
    """Deletes a session.

    Args:
      app_name: Reasoning engine id or full ReasoningEngine resource name.
      user_id: Unused by this implementation.
      session_id: Id of the session to delete.

    Raises:
      ValueError: If ``app_name`` is invalid.
      Exception: Any error raised by the DELETE request is logged and
        re-raised unchanged.
    """
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)
    api_client = _get_api_client(self.project, self.location)

    try:
      await api_client.async_request(
          http_method='DELETE',
          path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
          request_dict={},
      )
    except Exception as e:
      logger.error('Error deleting session %s: %s', session_id, e)
      # Bare raise keeps the original traceback intact.
      raise

  @override
  async def append_event(self, session: Session, event: Event) -> Event:
    """Appends an event locally and then sends it to the session service.

    Args:
      session: Session the event belongs to.
      event: Event to append.

    Returns:
      The same ``event`` that was passed in.

    Raises:
      ValueError: If ``session.app_name`` is invalid.
    """
    # Update the in-memory session.
    await super().append_event(session=session, event=event)

    reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
    api_client = _get_api_client(self.project, self.location)
    await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
        request_dict=_convert_event_to_json(event),
    )
    return event
def _get_api_client(project: str, location: str):
  """Instantiates an API client for the given project and location.

  A new client is built on every call: it needs to be instantiated inside
  each request so that event loop management works correctly.
  """
  return genai.Client(
      vertexai=True, project=project, location=location
  )._api_client
def _convert_event_to_json(event: Event) -> Dict[str, Any]:
  """Serializes an event into the request body sent to ``:appendEvent``.

  Args:
    event: Event to serialize.

  Returns:
    A dict with author, invocation id, timestamp (split into whole seconds
    and nanoseconds), error fields and event metadata. ``actions`` and
    ``content`` are added only when the event has them.
  """
  tool_ids = event.long_running_tool_ids
  metadata_json = {
      'partial': event.partial,
      'turn_complete': event.turn_complete,
      'interrupted': event.interrupted,
      'branch': event.branch,
      'long_running_tool_ids': list(tool_ids) if tool_ids else None,
  }
  if event.grounding_metadata:
    metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
        exclude_none=True, mode='json'
    )

  # Split the float timestamp into its integer and fractional parts.
  whole_seconds = int(event.timestamp)
  fraction = event.timestamp - whole_seconds

  # error_code / error_message are always emitted, even when unset.
  event_json = {
      'author': event.author,
      'invocation_id': event.invocation_id,
      'timestamp': {
          'seconds': whole_seconds,
          'nanos': int(fraction * 1_000_000_000),
      },
      'error_code': event.error_code,
      'error_message': event.error_message,
      'event_metadata': metadata_json,
  }

  actions = event.actions
  if actions:
    event_json['actions'] = {
        'skip_summarization': actions.skip_summarization,
        'state_delta': actions.state_delta,
        'artifact_delta': actions.artifact_delta,
        'transfer_agent': actions.transfer_to_agent,
        'escalate': actions.escalate,
        'requested_auth_configs': actions.requested_auth_configs,
    }

  if event.content:
    event_json['content'] = event.content.model_dump(
        exclude_none=True, mode='json'
    )

  return event_json
def _from_api_event(api_event: Dict[str, Any]) -> Event:
  """Builds an ``Event`` from one event dict returned by the session API.

  Args:
    api_event: Event dict using the API's camelCase keys. ``name``,
      ``invocationId``, ``author`` and ``timestamp`` are required; the
      remaining keys are optional.

  Returns:
    The reconstructed event. Metadata fields are only assigned when the dict
    carries a non-empty ``eventMetadata`` entry.
  """
  actions_json = api_event.get('actions', None)
  if actions_json:
    event_actions = EventActions(
        skip_summarization=actions_json.get('skipSummarization', None),
        state_delta=actions_json.get('stateDelta', {}),
        artifact_delta=actions_json.get('artifactDelta', {}),
        transfer_to_agent=actions_json.get('transferAgent', None),
        escalate=actions_json.get('escalate', None),
        requested_auth_configs=actions_json.get('requestedAuthConfigs', {}),
    )
  else:
    event_actions = EventActions()

  # The event id is the last path segment of the resource name.
  event_id = api_event['name'].split('/')[-1]
  event = Event(
      id=event_id,
      invocation_id=api_event['invocationId'],
      author=api_event['author'],
      actions=event_actions,
      content=_session_util.decode_content(api_event.get('content', None)),
      timestamp=isoparse(api_event['timestamp']).timestamp(),
      error_code=api_event.get('errorCode', None),
      error_message=api_event.get('errorMessage', None),
  )

  metadata_json = api_event.get('eventMetadata', None)
  if not metadata_json:
    return event

  tool_ids = metadata_json.get('longRunningToolIds', None)
  event.partial = metadata_json.get('partial', None)
  event.turn_complete = metadata_json.get('turnComplete', None)
  event.interrupted = metadata_json.get('interrupted', None)
  event.branch = metadata_json.get('branch', None)
  event.grounding_metadata = _session_util.decode_grounding_metadata(
      metadata_json.get('groundingMetadata', None)
  )
  # The API sends a list; the event stores a set (or None when empty).
  event.long_running_tool_ids = set(tool_ids) if tool_ids else None
  return event
def _parse_reasoning_engine_id(app_name: str) -> str:
  """Extracts the reasoning engine id from an app name.

  Args:
    app_name: Either the bare numeric reasoning engine id, or the full
      resource name
      ``projects/{project}/locations/{location}/reasoningEngines/{id}``.

  Returns:
    The reasoning engine id as a string of ASCII digits.

  Raises:
    ValueError: If ``app_name`` matches neither accepted form.
  """
  # Match ASCII digits explicitly: str.isdigit() and an unflagged \d also
  # accept non-ASCII digits (e.g. superscripts or Arabic-Indic numerals),
  # which would otherwise be interpolated into request paths as an "id".
  if re.fullmatch(r'[0-9]+', app_name):
    return app_name

  # The hyphen sits last in each class so it is unambiguously a literal.
  # fullmatch already anchors both ends, so no ^/$ is needed.
  pattern = (
      r'projects/([a-zA-Z0-9_-]+)/locations/([a-zA-Z0-9_-]+)'
      r'/reasoningEngines/([0-9]+)'
  )
  match = re.fullmatch(pattern, app_name)
  if match is None:
    raise ValueError(
        f'App name {app_name} is not valid. It should either be the full'
        ' ReasoningEngine resource name, or the reasoning engine id.'
    )

  return match.group(3)