feat! Update session service interface to be async.

Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option.

PiperOrigin-RevId: 759252188
This commit is contained in:
Google Team Member
2025-05-15 12:23:33 -07:00
committed by Copybara-Service
parent 5b3204c356
commit 1804ca39a6
23 changed files with 268 additions and 264 deletions

View File

@@ -55,7 +55,7 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session = await session_service.create_session(
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
@@ -130,7 +130,7 @@ async def run_cli(
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = await session_service.create_session(
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
@@ -145,12 +145,14 @@ async def run_cli(
input_path=input_file,
)
elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())
if loaded_session:
for event in loaded_session.events:
await session_service.append_event(session, event)
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
@@ -179,7 +181,7 @@ async def run_cli(
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details.
session = await session_service.get_session(
session = session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,

View File

@@ -333,12 +333,10 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
async def get_session(
app_name: str, user_id: str, session_id: str
) -> Session:
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
session = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
@@ -349,15 +347,14 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
return [
session
for session in list_sessions_response.sessions
for session in session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
# Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
]
@@ -366,7 +363,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
async def create_session_with_id(
def create_session_with_id(
app_name: str,
user_id: str,
session_id: str,
@@ -375,7 +372,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
if (
await session_service.get_session(
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
is not None
@@ -385,7 +382,7 @@ def get_fast_api_app(
status_code=400, detail=f"Session already exists: {session_id}"
)
logger.info("New session created: %s", session_id)
return await session_service.create_session(
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id
)
@@ -393,7 +390,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
async def create_session(
def create_session(
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
@@ -401,7 +398,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return await session_service.create_session(
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
@@ -445,7 +442,7 @@ def get_fast_api_app(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
# Get the session
session = await session_service.get_session(
session = session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
assert session, "Session not found."
@@ -533,7 +530,7 @@ def get_fast_api_app(
session_id=eval_case_result.session_id,
)
)
eval_case_result.session_details = await session_service.get_session(
eval_case_result.session_details = session_service.get_session(
app_name=app_name,
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
@@ -618,10 +615,10 @@ def get_fast_api_app(
return eval_result_files
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
async def delete_session(app_name: str, user_id: str, session_id: str):
def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
await session_service.delete_session(
session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -716,7 +713,7 @@ def get_fast_api_app(
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
session = await session_service.get_session(
session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
@@ -738,7 +735,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = await session_service.get_session(
session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
@@ -779,7 +776,7 @@ def get_fast_api_app(
):
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
session_events = session.events if session else []
@@ -836,7 +833,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
if not session:

View File

@@ -126,7 +126,7 @@ class EvaluationGenerator:
user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4())
_ = await session_service.create_session(
_ = session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.state if initial_session else {},

View File

@@ -173,7 +173,7 @@ class Runner:
The events generated by the agent.
"""
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
session = self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
@@ -197,7 +197,7 @@ class Runner:
invocation_context.agent = self._find_agent_to_run(session, root_agent)
async for event in invocation_context.agent.run_async(invocation_context):
if not event.partial:
await self.session_service.append_event(session=session, event=event)
self.session_service.append_event(session=session, event=event)
yield event
async def _append_new_message_to_session(
@@ -242,7 +242,7 @@ class Runner:
author='user',
content=new_message,
)
await self.session_service.append_event(session=session, event=event)
self.session_service.append_event(session=session, event=event)
async def run_live(
self,
@@ -324,7 +324,7 @@ class Runner:
)
async for event in invocation_context.agent.run_live(invocation_context):
await self.session_service.append_event(session=session, event=event)
self.session_service.append_event(session=session, event=event)
yield event
async def close_session(self, session: Session):
@@ -335,7 +335,7 @@ class Runner:
"""
if self.memory_service:
await self.memory_service.add_session_to_memory(session)
await self.session_service.close_session(session=session)
self.session_service.close_session(session=session)
def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent

View File

@@ -47,7 +47,7 @@ class BaseSessionService(abc.ABC):
"""
@abc.abstractmethod
async def create_session(
def create_session(
self,
*,
app_name: str,
@@ -67,9 +67,10 @@ class BaseSessionService(abc.ABC):
Returns:
session: The newly created session instance.
"""
pass
@abc.abstractmethod
async def get_session(
def get_session(
self,
*,
app_name: str,
@@ -78,24 +79,28 @@ class BaseSessionService(abc.ABC):
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
"""Gets a session."""
pass
@abc.abstractmethod
async def list_sessions(
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
"""Lists all the sessions."""
pass
@abc.abstractmethod
async def delete_session(
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
"""Deletes a session."""
pass
async def close_session(self, *, session: Session):
def close_session(self, *, session: Session):
"""Closes a session."""
# TODO: determine whether we want to finalize the session here.
pass
async def append_event(self, session: Session, event: Event) -> Event:
def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event

View File

@@ -283,7 +283,7 @@ class DatabaseSessionService(BaseSessionService):
Base.metadata.create_all(self.db_engine)
@override
async def create_session(
def create_session(
self,
*,
app_name: str,
@@ -357,7 +357,7 @@ class DatabaseSessionService(BaseSessionService):
return session
@override
async def get_session(
def get_session(
self,
*,
app_name: str,
@@ -431,7 +431,7 @@ class DatabaseSessionService(BaseSessionService):
return session
@override
async def list_sessions(
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
with self.DatabaseSessionFactory() as sessionFactory:
@@ -454,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
return ListSessionsResponse(sessions=sessions)
@override
async def delete_session(
def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
with self.DatabaseSessionFactory() as sessionFactory:
@@ -467,7 +467,7 @@ class DatabaseSessionService(BaseSessionService):
sessionFactory.commit()
@override
async def append_event(self, session: Session, event: Event) -> Event:
def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}")
if event.partial:
@@ -552,10 +552,9 @@ class DatabaseSessionService(BaseSessionService):
session.last_update_time = storage_session.update_time.timestamp()
# Also update the in-memory session
await super().append_event(session=session, event=event)
super().append_event(session=session, event=event)
return event
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
return Event(

View File

@@ -44,7 +44,7 @@ class InMemorySessionService(BaseSessionService):
self.app_state: dict[str, dict[str, Any]] = {}
@override
async def create_session(
def create_session(
self,
*,
app_name: str,
@@ -106,7 +106,7 @@ class InMemorySessionService(BaseSessionService):
return self._merge_state(app_name, user_id, copied_session)
@override
async def get_session(
def get_session(
self,
*,
app_name: str,
@@ -193,7 +193,7 @@ class InMemorySessionService(BaseSessionService):
return copied_session
@override
async def list_sessions(
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
@@ -221,7 +221,7 @@ class InMemorySessionService(BaseSessionService):
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)
async def delete_session(
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
self._delete_session_impl(
@@ -250,9 +250,16 @@ class InMemorySessionService(BaseSessionService):
self.sessions[app_name][user_id].pop(session_id)
@override
async def append_event(self, session: Session, event: Event) -> Event:
def append_event(self, session: Session, event: Event) -> Event:
return self._append_event_impl(session=session, event=event)
def append_event_sync(self, session: Session, event: Event) -> Event:
logger.warning('Deprecated. Please migrate to the async method.')
return self._append_event_impl(session=session, event=event)
def _append_event_impl(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
await super().append_event(session=session, event=event)
super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Update the storage session
@@ -279,7 +286,7 @@ class InMemorySessionService(BaseSessionService):
] = event.actions.state_delta[key]
storage_session = self.sessions[app_name][user_id].get(session_id)
await super().append_event(session=storage_session, event=event)
super().append_event(session=storage_session, event=event)
storage_session.last_update_time = event.timestamp

View File

@@ -48,7 +48,7 @@ class VertexAiSessionService(BaseSessionService):
self.api_client = client._api_client
@override
async def create_session(
def create_session(
self,
*,
app_name: str,
@@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService):
if state:
session_json_dict['session_state'] = state
api_response = await self.api_client.async_request(
api_response = self.api_client.request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
request_dict=session_json_dict,
@@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt = 5
while max_retry_attempt >= 0:
lro_response = await self.api_client.async_request(
lro_response = self.api_client.request(
http_method='GET',
path=f'operations/{operation_id}',
request_dict={},
@@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt -= 1
# Get session resource
get_session_api_response = await self.api_client.async_request(
get_session_api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
@@ -112,7 +112,7 @@ class VertexAiSessionService(BaseSessionService):
return session
@override
async def get_session(
def get_session(
self,
*,
app_name: str,
@@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService):
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
# Get session resource
get_session_api_response = await self.api_client.async_request(
get_session_api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
@@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService):
last_update_time=update_timestamp,
)
list_events_api_response = await self.api_client.async_request(
list_events_api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={},
@@ -175,7 +175,7 @@ class VertexAiSessionService(BaseSessionService):
return session
@override
async def list_sessions(
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
@@ -202,23 +202,23 @@ class VertexAiSessionService(BaseSessionService):
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
async def delete_session(
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
await self.api_client.async_request(
self.api_client.request(
http_method='DELETE',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
)
@override
async def append_event(self, session: Session, event: Event) -> Event:
def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
await super().append_event(session=session, event=event)
super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
await self.api_client.async_request(
self.api_client.request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
request_dict=_convert_event_to_json(event),

View File

@@ -129,7 +129,7 @@ class AgentTool(BaseTool):
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
session = await runner.session_service.create_session(
session = runner.session_service.create_session(
app_name=self.agent.name,
user_id='tmp_user',
state=tool_context.state.to_dict(),