chore: Minor improvement to session service

- Add missing override.
- Add warning to failed actions.
- Remove unused import.
- Remove unused fields.
- Add type checking.

PiperOrigin-RevId: 767209697
This commit is contained in:
Shangjie Chen 2025-06-04 10:57:51 -07:00 committed by Copybara-Service
parent 54ed031d1a
commit c6e1e82efb
2 changed files with 30 additions and 12 deletions

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import copy import copy
import logging import logging
@ -223,6 +224,7 @@ class InMemorySessionService(BaseSessionService):
sessions_without_events.append(copied_session) sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events) return ListSessionsResponse(sessions=sessions_without_events)
@override
async def delete_session( async def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
@ -247,7 +249,7 @@ class InMemorySessionService(BaseSessionService):
) )
is None is None
): ):
return None return
self.sessions[app_name][user_id].pop(session_id) self.sessions[app_name][user_id].pop(session_id)
@ -261,11 +263,20 @@ class InMemorySessionService(BaseSessionService):
app_name = session.app_name app_name = session.app_name
user_id = session.user_id user_id = session.user_id
session_id = session.id session_id = session.id
def _warning(message: str) -> None:
logger.warning(
f'Failed to append event to session {session_id}: {message}'
)
if app_name not in self.sessions: if app_name not in self.sessions:
_warning(f'app_name {app_name} not in sessions')
return event return event
if user_id not in self.sessions[app_name]: if user_id not in self.sessions[app_name]:
_warning(f'user_id {user_id} not in sessions[app_name]')
return event return event
if session_id not in self.sessions[app_name][user_id]: if session_id not in self.sessions[app_name][user_id]:
_warning(f'session_id {session_id} not in sessions[app_name][user_id]')
return event return event
if event.actions and event.actions.state_delta: if event.actions and event.actions.state_delta:

View File

@ -16,8 +16,8 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import re import re
import time
from typing import Any from typing import Any
from typing import Dict
from typing import Optional from typing import Optional
import urllib.parse import urllib.parse
@ -50,9 +50,6 @@ class VertexAiSessionService(BaseSessionService):
self.project = project self.project = project
self.location = location self.location = location
client = genai.Client(vertexai=True, project=project, location=location)
self.api_client = client._api_client
@override @override
async def create_session( async def create_session(
self, self,
@ -86,6 +83,7 @@ class VertexAiSessionService(BaseSessionService):
operation_id = api_response['name'].split('/')[-1] operation_id = api_response['name'].split('/')[-1]
max_retry_attempt = 5 max_retry_attempt = 5
lro_response = None
while max_retry_attempt >= 0: while max_retry_attempt >= 0:
lro_response = await api_client.async_request( lro_response = await api_client.async_request(
http_method='GET', http_method='GET',
@ -99,6 +97,11 @@ class VertexAiSessionService(BaseSessionService):
await asyncio.sleep(1) await asyncio.sleep(1)
max_retry_attempt -= 1 max_retry_attempt -= 1
if lro_response is None or not lro_response.get('done', None):
raise TimeoutError(
f'Timeout waiting for operation {operation_id} to complete.'
)
# Get session resource # Get session resource
get_session_api_response = await api_client.async_request( get_session_api_response = await api_client.async_request(
http_method='GET', http_method='GET',
@ -235,11 +238,15 @@ class VertexAiSessionService(BaseSessionService):
) -> None: ) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
api_client = _get_api_client(self.project, self.location) api_client = _get_api_client(self.project, self.location)
await api_client.async_request( try:
http_method='DELETE', await api_client.async_request(
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', http_method='DELETE',
request_dict={}, path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
) request_dict={},
)
except Exception as e:
logger.error(f'Error deleting session {session_id}: {e}')
raise e
@override @override
async def append_event(self, session: Session, event: Event) -> Event: async def append_event(self, session: Session, event: Event) -> Event:
@ -266,7 +273,7 @@ def _get_api_client(project: str, location: str):
return client._api_client return client._api_client
def _convert_event_to_json(event: Event): def _convert_event_to_json(event: Event) -> Dict[str, Any]:
metadata_json = { metadata_json = {
'partial': event.partial, 'partial': event.partial,
'turn_complete': event.turn_complete, 'turn_complete': event.turn_complete,
@ -318,7 +325,7 @@ def _convert_event_to_json(event: Event):
return event_json return event_json
def _from_api_event(api_event: dict) -> Event: def _from_api_event(api_event: Dict[str, Any]) -> Event:
event_actions = EventActions() event_actions = EventActions()
if api_event.get('actions', None): if api_event.get('actions', None):
event_actions = EventActions( event_actions = EventActions(