adk-python/src/google/adk/sessions/vertex_ai_session_service.py

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import time
from typing import Any
from typing import Optional
from dateutil import parser
from google import genai
from typing_extensions import override
from ..events.event import Event
from ..events.event_actions import EventActions
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .session import Session

isoparse = parser.isoparse
logger = logging.getLogger('google_adk.' + __name__)


class VertexAiSessionService(BaseSessionService):
  """Connects to the managed Vertex AI Session Service."""

  def __init__(
      self,
      project: Optional[str] = None,
      location: Optional[str] = None,
  ):
    self.project = project
    self.location = location

    client = genai.Client(vertexai=True, project=project, location=location)
    self.api_client = client._api_client

  @override
  async def create_session(
      self,
      *,
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
      session_id: Optional[str] = None,
  ) -> Session:
    if session_id:
      raise ValueError(
          'User-provided Session id is not supported for'
          ' VertexAISessionService.'
      )

    reasoning_engine_id = _parse_reasoning_engine_id(app_name)
    session_json_dict = {'user_id': user_id}
    if state:
      session_json_dict['session_state'] = state

    api_response = self.api_client.request(
        http_method='POST',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions',
        request_dict=session_json_dict,
    )
    logger.info(f'Create Session response {api_response}')

    session_id = api_response['name'].split('/')[-3]
    operation_id = api_response['name'].split('/')[-1]

    # Session creation is handled as a long-running operation; poll until it
    # reports done or the retry budget is exhausted.
    max_retry_attempt = 5
    while max_retry_attempt >= 0:
      lro_response = self.api_client.request(
          http_method='GET',
          path=f'operations/{operation_id}',
          request_dict={},
      )

      if lro_response.get('done', None):
        break

      time.sleep(1)
      max_retry_attempt -= 1

    # Get session resource
    get_session_api_response = self.api_client.request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

    update_timestamp = isoparse(
        get_session_api_response['updateTime']
    ).timestamp()
    session = Session(
        app_name=str(app_name),
        user_id=str(user_id),
        id=str(session_id),
        state=get_session_api_response.get('sessionState', {}),
        last_update_time=update_timestamp,
    )
    return session

  @override
  async def get_session(
      self,
      *,
      app_name: str,
      user_id: str,
      session_id: str,
      config: Optional[GetSessionConfig] = None,
  ) -> Optional[Session]:
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)

    # Get session resource
    get_session_api_response = self.api_client.request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

    session_id = get_session_api_response['name'].split('/')[-1]
    update_timestamp = isoparse(
        get_session_api_response['updateTime']
    ).timestamp()
    session = Session(
        app_name=str(app_name),
        user_id=str(user_id),
        id=str(session_id),
        state=get_session_api_response.get('sessionState', {}),
        last_update_time=update_timestamp,
    )

    list_events_api_response = self.api_client.request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
        request_dict={},
    )

    # Handles empty response case
    if list_events_api_response.get('httpHeaders', None):
      return session

    session.events = [
        _from_api_event(event)
        for event in list_events_api_response['sessionEvents']
    ]
    session.events = [
        event
        for event in session.events
        if event.timestamp <= update_timestamp
    ]
    session.events.sort(key=lambda event: event.timestamp)

    if config:
      if config.num_recent_events:
        session.events = session.events[-config.num_recent_events :]
      elif config.after_timestamp:
        i = len(session.events) - 1
        while i >= 0:
          if session.events[i].timestamp < config.after_timestamp:
            break
          i -= 1
        if i >= 0:
          session.events = session.events[i:]

    return session

  @override
  async def list_sessions(
      self, *, app_name: str, user_id: str
  ) -> ListSessionsResponse:
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)
    api_response = self.api_client.request(
        http_method='GET',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
        request_dict={},
    )

    # Handles empty response case
    if api_response.get('httpHeaders', None):
      return ListSessionsResponse()

    sessions = []
    for api_session in api_response['sessions']:
      session = Session(
          app_name=app_name,
          user_id=user_id,
          id=api_session['name'].split('/')[-1],
          state={},
          last_update_time=isoparse(api_session['updateTime']).timestamp(),
      )
      sessions.append(session)
    return ListSessionsResponse(sessions=sessions)

  @override
  async def delete_session(
      self, *, app_name: str, user_id: str, session_id: str
  ) -> None:
    reasoning_engine_id = _parse_reasoning_engine_id(app_name)
    self.api_client.request(
        http_method='DELETE',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
        request_dict={},
    )

  @override
  async def append_event(self, session: Session, event: Event) -> Event:
    # Update the in-memory session.
    await super().append_event(session=session, event=event)

    reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
    self.api_client.request(
        http_method='POST',
        path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
        request_dict=_convert_event_to_json(event),
    )
    return event
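

# The helpers below convert between the ADK `Event` model and the JSON
# payloads used by the Vertex AI session events API: requests sent via
# `appendEvent` use snake_case keys, while events read back from the list
# endpoint use camelCase keys.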
def _convert_event_to_json(event: Event):
  metadata_json = {
      'partial': event.partial,
      'turn_complete': event.turn_complete,
      'interrupted': event.interrupted,
      'branch': event.branch,
      'long_running_tool_ids': (
          list(event.long_running_tool_ids)
          if event.long_running_tool_ids
          else None
      ),
  }
  if event.grounding_metadata:
    metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
        exclude_none=True
    )

  event_json = {
      'author': event.author,
      'invocation_id': event.invocation_id,
      'timestamp': {
          'seconds': int(event.timestamp),
          'nanos': int(
              (event.timestamp - int(event.timestamp)) * 1_000_000_000
          ),
      },
      'error_code': event.error_code,
      'error_message': event.error_message,
      'event_metadata': metadata_json,
  }

  if event.actions:
    actions_json = {
        'skip_summarization': event.actions.skip_summarization,
        'state_delta': event.actions.state_delta,
        'artifact_delta': event.actions.artifact_delta,
        'transfer_agent': event.actions.transfer_to_agent,
        'escalate': event.actions.escalate,
        'requested_auth_configs': event.actions.requested_auth_configs,
    }
    event_json['actions'] = actions_json
  if event.content:
    event_json['content'] = _session_util.encode_content(event.content)
  if event.error_code:
    event_json['error_code'] = event.error_code
  if event.error_message:
    event_json['error_message'] = event.error_message
  return event_json


def _from_api_event(api_event: dict) -> Event:
  event_actions = EventActions()
  if api_event.get('actions', None):
    event_actions = EventActions(
        skip_summarization=api_event['actions'].get('skipSummarization', None),
        state_delta=api_event['actions'].get('stateDelta', {}),
        artifact_delta=api_event['actions'].get('artifactDelta', {}),
        transfer_to_agent=api_event['actions'].get('transferAgent', None),
        escalate=api_event['actions'].get('escalate', None),
        requested_auth_configs=api_event['actions'].get(
            'requestedAuthConfigs', {}
        ),
    )

  event = Event(
      id=api_event['name'].split('/')[-1],
      invocation_id=api_event['invocationId'],
      author=api_event['author'],
      actions=event_actions,
      content=_session_util.decode_content(api_event.get('content', None)),
      timestamp=isoparse(api_event['timestamp']).timestamp(),
      error_code=api_event.get('errorCode', None),
      error_message=api_event.get('errorMessage', None),
  )

  if api_event.get('eventMetadata', None):
    long_running_tool_ids_list = api_event['eventMetadata'].get(
        'longRunningToolIds', None
    )
    event.partial = api_event['eventMetadata'].get('partial', None)
    event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
    event.interrupted = api_event['eventMetadata'].get('interrupted', None)
    event.branch = api_event['eventMetadata'].get('branch', None)
    event.grounding_metadata = api_event['eventMetadata'].get(
        'groundingMetadata', None
    )
    event.long_running_tool_ids = (
        set(long_running_tool_ids_list) if long_running_tool_ids_list else None
    )
  return event


def _parse_reasoning_engine_id(app_name: str):
  if app_name.isdigit():
    return app_name

  pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
  match = re.fullmatch(pattern, app_name)

  if not match:
    raise ValueError(
        f'App name {app_name} is not valid. It should either be the full'
        ' ReasoningEngine resource name, or the reasoning engine id.'
    )

  return match.groups()[-1]
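

if __name__ == '__main__':
  # Illustrative usage sketch only; not part of the library. It assumes a
  # Google Cloud project with Vertex AI enabled, application-default
  # credentials, and an existing reasoning engine; the project, location,
  # and engine ID below are placeholders to replace before running.
  import asyncio

  async def _demo() -> None:
    service = VertexAiSessionService(
        project='my-project', location='us-central1'
    )
    session = await service.create_session(
        app_name='1234567890', user_id='demo-user'
    )
    fetched = await service.get_session(
        app_name='1234567890', user_id='demo-user', session_id=session.id
    )
    if fetched is not None:
      print(f'Fetched session {fetched.id} with {len(fetched.events)} event(s)')
    await service.delete_session(
        app_name='1234567890', user_id='demo-user', session_id=session.id
    )

  asyncio.run(_demo())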