feat! Update session service interface to be async.

Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option.

PiperOrigin-RevId: 759252188
This commit is contained in:
Google Team Member
2025-05-15 12:23:33 -07:00
committed by Copybara-Service
parent 5b3204c356
commit 1804ca39a6
23 changed files with 268 additions and 264 deletions

View File

@@ -56,7 +56,7 @@ class ModelContent(types.Content):
super().__init__(role='model', parts=parts)
async def create_invocation_context(agent: Agent, user_content: str = ''):
def create_invocation_context(agent: Agent, user_content: str = ''):
invocation_id = 'test_id'
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
@@ -67,7 +67,7 @@ async def create_invocation_context(agent: Agent, user_content: str = ''):
memory_service=memory_service,
invocation_id=invocation_id,
agent=agent,
session=await session_service.create_session(
session=session_service.create_session(
app_name='test_app', user_id='test_user'
),
user_content=types.Content(
@@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner):
self, new_message: types.ContentUnion
) -> list[Event]:
session = await self.session_service.create_session(
session = self.session_service.create_session(
app_name='InMemoryRunner', user_id='test_user'
)
collected_events = []
@@ -172,22 +172,14 @@ class InMemoryRunner:
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
self.session_id = None
self.session_id = self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
).id
@property
def session(self) -> Session:
if not self.session_id:
session = asyncio.run(
self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
)
)
self.session_id = session.id
return session
return asyncio.run(
self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
)
return self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
)
def run(self, new_message: types.ContentUnion) -> list[Event]:
@@ -202,9 +194,9 @@ class InMemoryRunner:
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
collected_responses = []
async def consume_responses(session: Session):
async def consume_responses():
run_res = self.runner.run_live(
session=session,
session=self.session,
live_request_queue=live_request_queue,
)
@@ -215,8 +207,7 @@ class InMemoryRunner:
return
try:
session = self.session
asyncio.run(consume_responses(session))
asyncio.run(consume_responses())
except asyncio.TimeoutError:
print('Returning any partial results collected so far.')