feat! Update session service interface to be async.

Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option.

PiperOrigin-RevId: 759224250
This commit is contained in:
Shangjie Chen
2025-05-15 11:16:43 -07:00
committed by Copybara-Service
parent d161a2c3f7
commit 5b3204c356
23 changed files with 264 additions and 268 deletions

View File

@@ -56,7 +56,7 @@ class ModelContent(types.Content):
super().__init__(role='model', parts=parts)
def create_invocation_context(agent: Agent, user_content: str = ''):
async def create_invocation_context(agent: Agent, user_content: str = ''):
invocation_id = 'test_id'
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
@@ -67,7 +67,7 @@ def create_invocation_context(agent: Agent, user_content: str = ''):
memory_service=memory_service,
invocation_id=invocation_id,
agent=agent,
session=session_service.create_session(
session=await session_service.create_session(
app_name='test_app', user_id='test_user'
),
user_content=types.Content(
@@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner):
self, new_message: types.ContentUnion
) -> list[Event]:
session = self.session_service.create_session(
session = await self.session_service.create_session(
app_name='InMemoryRunner', user_id='test_user'
)
collected_events = []
@@ -172,14 +172,22 @@ class InMemoryRunner:
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
self.session_id = self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
).id
self.session_id = None
@property
def session(self) -> Session:
return self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
if not self.session_id:
session = asyncio.run(
self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
)
)
self.session_id = session.id
return session
return asyncio.run(
self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
)
)
def run(self, new_message: types.ContentUnion) -> list[Event]:
@@ -194,9 +202,9 @@ class InMemoryRunner:
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
collected_responses = []
async def consume_responses():
async def consume_responses(session: Session):
run_res = self.runner.run_live(
session=self.session,
session=session,
live_request_queue=live_request_queue,
)
@@ -207,7 +215,8 @@ class InMemoryRunner:
return
try:
asyncio.run(consume_responses())
session = self.session
asyncio.run(consume_responses(session))
except asyncio.TimeoutError:
print('Returning any partial results collected so far.')