feat! Update session service interface to be async.

Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option.

PiperOrigin-RevId: 759252188
This commit is contained in:
Google Team Member 2025-05-15 12:23:33 -07:00 committed by Copybara-Service
parent 5b3204c356
commit 1804ca39a6
23 changed files with 268 additions and 264 deletions

View File

@ -55,7 +55,7 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read()) input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now() input_file.state['_time'] = datetime.now()
session = await session_service.create_session( session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state app_name=app_name, user_id=user_id, state=input_file.state
) )
for query in input_file.queries: for query in input_file.queries:
@ -130,7 +130,7 @@ async def run_cli(
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name) agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user' user_id = 'test_user'
session = await session_service.create_session( session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id app_name=agent_folder_name, user_id=user_id
) )
root_agent = agent_module.agent.root_agent root_agent = agent_module.agent.root_agent
@ -145,12 +145,14 @@ async def run_cli(
input_path=input_file, input_path=input_file,
) )
elif saved_session_file: elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f: with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read()) loaded_session = Session.model_validate_json(f.read())
if loaded_session: if loaded_session:
for event in loaded_session.events: for event in loaded_session.events:
await session_service.append_event(session, event) session_service.append_event(session, event)
content = event.content content = event.content
if not content or not content.parts or not content.parts[0].text: if not content or not content.parts or not content.parts[0].text:
continue continue
@ -179,7 +181,7 @@ async def run_cli(
session_path = f'{agent_module_path}/{session_id}.session.json' session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details. # Fetch the session again to get all the details.
session = await session_service.get_session( session = session_service.get_session(
app_name=session.app_name, app_name=session.app_name,
user_id=session.user_id, user_id=session.user_id,
session_id=session.id, session_id=session.id,

View File

@ -333,12 +333,10 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
async def get_session( def get_session(app_name: str, user_id: str, session_id: str) -> Session:
app_name: str, user_id: str, session_id: str
) -> Session:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session( session = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
if not session: if not session:
@ -349,15 +347,14 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions", "/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
async def list_sessions(app_name: str, user_id: str) -> list[Session]: def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
return [ return [
session session
for session in list_sessions_response.sessions for session in session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
# Remove sessions that were generated as a part of Eval. # Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX) if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
] ]
@ -366,7 +363,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
async def create_session_with_id( def create_session_with_id(
app_name: str, app_name: str,
user_id: str, user_id: str,
session_id: str, session_id: str,
@ -375,7 +372,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
if ( if (
await session_service.get_session( session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
is not None is not None
@ -385,7 +382,7 @@ def get_fast_api_app(
status_code=400, detail=f"Session already exists: {session_id}" status_code=400, detail=f"Session already exists: {session_id}"
) )
logger.info("New session created: %s", session_id) logger.info("New session created: %s", session_id)
return await session_service.create_session( return session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id app_name=app_name, user_id=user_id, state=state, session_id=session_id
) )
@ -393,7 +390,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions", "/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
async def create_session( def create_session(
app_name: str, app_name: str,
user_id: str, user_id: str,
state: Optional[dict[str, Any]] = None, state: Optional[dict[str, Any]] = None,
@ -401,7 +398,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created") logger.info("New session created")
return await session_service.create_session( return session_service.create_session(
app_name=app_name, user_id=user_id, state=state app_name=app_name, user_id=user_id, state=state
) )
@ -445,7 +442,7 @@ def get_fast_api_app(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
): ):
# Get the session # Get the session
session = await session_service.get_session( session = session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id app_name=app_name, user_id=req.user_id, session_id=req.session_id
) )
assert session, "Session not found." assert session, "Session not found."
@ -533,7 +530,7 @@ def get_fast_api_app(
session_id=eval_case_result.session_id, session_id=eval_case_result.session_id,
) )
) )
eval_case_result.session_details = await session_service.get_session( eval_case_result.session_details = session_service.get_session(
app_name=app_name, app_name=app_name,
user_id=eval_case_result.user_id, user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id, session_id=eval_case_result.session_id,
@ -618,10 +615,10 @@ def get_fast_api_app(
return eval_result_files return eval_result_files
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
async def delete_session(app_name: str, user_id: str, session_id: str): def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
await session_service.delete_session( session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
@ -716,7 +713,7 @@ def get_fast_api_app(
async def agent_run(req: AgentRunRequest) -> list[Event]: async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name app_id = agent_engine_id if agent_engine_id else req.app_name
session = await session_service.get_session( session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id app_name=app_id, user_id=req.user_id, session_id=req.session_id
) )
if not session: if not session:
@ -738,7 +735,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint # SSE endpoint
session = await session_service.get_session( session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id app_name=app_id, user_id=req.user_id, session_id=req.session_id
) )
if not session: if not session:
@ -779,7 +776,7 @@ def get_fast_api_app(
): ):
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name app_id = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session( session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id app_name=app_id, user_id=user_id, session_id=session_id
) )
session_events = session.events if session else [] session_events = session.events if session else []
@ -836,7 +833,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name app_id = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session( session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id app_name=app_id, user_id=user_id, session_id=session_id
) )
if not session: if not session:

View File

@ -126,7 +126,7 @@ class EvaluationGenerator:
user_id = initial_session.user_id if initial_session else "test_user_id" user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4()) session_id = session_id if session_id else str(uuid.uuid4())
_ = await session_service.create_session( _ = session_service.create_session(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
state=initial_session.state if initial_session else {}, state=initial_session.state if initial_session else {},

View File

@ -173,7 +173,7 @@ class Runner:
The events generated by the agent. The events generated by the agent.
""" """
with tracer.start_as_current_span('invocation'): with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session( session = self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id app_name=self.app_name, user_id=user_id, session_id=session_id
) )
if not session: if not session:
@ -197,7 +197,7 @@ class Runner:
invocation_context.agent = self._find_agent_to_run(session, root_agent) invocation_context.agent = self._find_agent_to_run(session, root_agent)
async for event in invocation_context.agent.run_async(invocation_context): async for event in invocation_context.agent.run_async(invocation_context):
if not event.partial: if not event.partial:
await self.session_service.append_event(session=session, event=event) self.session_service.append_event(session=session, event=event)
yield event yield event
async def _append_new_message_to_session( async def _append_new_message_to_session(
@ -242,7 +242,7 @@ class Runner:
author='user', author='user',
content=new_message, content=new_message,
) )
await self.session_service.append_event(session=session, event=event) self.session_service.append_event(session=session, event=event)
async def run_live( async def run_live(
self, self,
@ -324,7 +324,7 @@ class Runner:
) )
async for event in invocation_context.agent.run_live(invocation_context): async for event in invocation_context.agent.run_live(invocation_context):
await self.session_service.append_event(session=session, event=event) self.session_service.append_event(session=session, event=event)
yield event yield event
async def close_session(self, session: Session): async def close_session(self, session: Session):
@ -335,7 +335,7 @@ class Runner:
""" """
if self.memory_service: if self.memory_service:
await self.memory_service.add_session_to_memory(session) await self.memory_service.add_session_to_memory(session)
await self.session_service.close_session(session=session) self.session_service.close_session(session=session)
def _find_agent_to_run( def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent self, session: Session, root_agent: BaseAgent

View File

@ -47,7 +47,7 @@ class BaseSessionService(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
async def create_session( def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -67,9 +67,10 @@ class BaseSessionService(abc.ABC):
Returns: Returns:
session: The newly created session instance. session: The newly created session instance.
""" """
pass
@abc.abstractmethod @abc.abstractmethod
async def get_session( def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -78,24 +79,28 @@ class BaseSessionService(abc.ABC):
config: Optional[GetSessionConfig] = None, config: Optional[GetSessionConfig] = None,
) -> Optional[Session]: ) -> Optional[Session]:
"""Gets a session.""" """Gets a session."""
pass
@abc.abstractmethod @abc.abstractmethod
async def list_sessions( def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
"""Lists all the sessions.""" """Lists all the sessions."""
pass
@abc.abstractmethod @abc.abstractmethod
async def delete_session( def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
"""Deletes a session.""" """Deletes a session."""
pass
async def close_session(self, *, session: Session): def close_session(self, *, session: Session):
"""Closes a session.""" """Closes a session."""
# TODO: determine whether we want to finalize the session here. # TODO: determine whether we want to finalize the session here.
pass
async def append_event(self, session: Session, event: Event) -> Event: def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object.""" """Appends an event to a session object."""
if event.partial: if event.partial:
return event return event

View File

@ -283,7 +283,7 @@ class DatabaseSessionService(BaseSessionService):
Base.metadata.create_all(self.db_engine) Base.metadata.create_all(self.db_engine)
@override @override
async def create_session( def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -357,7 +357,7 @@ class DatabaseSessionService(BaseSessionService):
return session return session
@override @override
async def get_session( def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -431,7 +431,7 @@ class DatabaseSessionService(BaseSessionService):
return session return session
@override @override
async def list_sessions( def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
with self.DatabaseSessionFactory() as sessionFactory: with self.DatabaseSessionFactory() as sessionFactory:
@ -454,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
return ListSessionsResponse(sessions=sessions) return ListSessionsResponse(sessions=sessions)
@override @override
async def delete_session( def delete_session(
self, app_name: str, user_id: str, session_id: str self, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
with self.DatabaseSessionFactory() as sessionFactory: with self.DatabaseSessionFactory() as sessionFactory:
@ -467,7 +467,7 @@ class DatabaseSessionService(BaseSessionService):
sessionFactory.commit() sessionFactory.commit()
@override @override
async def append_event(self, session: Session, event: Event) -> Event: def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}") logger.info(f"Append event: {event} to session {session.id}")
if event.partial: if event.partial:
@ -552,10 +552,9 @@ class DatabaseSessionService(BaseSessionService):
session.last_update_time = storage_session.update_time.timestamp() session.last_update_time = storage_session.update_time.timestamp()
# Also update the in-memory session # Also update the in-memory session
await super().append_event(session=session, event=event) super().append_event(session=session, event=event)
return event return event
def convert_event(event: StorageEvent) -> Event: def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event.""" """Converts a storage event to an event."""
return Event( return Event(

View File

@ -44,7 +44,7 @@ class InMemorySessionService(BaseSessionService):
self.app_state: dict[str, dict[str, Any]] = {} self.app_state: dict[str, dict[str, Any]] = {}
@override @override
async def create_session( def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -106,7 +106,7 @@ class InMemorySessionService(BaseSessionService):
return self._merge_state(app_name, user_id, copied_session) return self._merge_state(app_name, user_id, copied_session)
@override @override
async def get_session( def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -193,7 +193,7 @@ class InMemorySessionService(BaseSessionService):
return copied_session return copied_session
@override @override
async def list_sessions( def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id) return self._list_sessions_impl(app_name=app_name, user_id=user_id)
@ -221,7 +221,7 @@ class InMemorySessionService(BaseSessionService):
sessions_without_events.append(copied_session) sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events) return ListSessionsResponse(sessions=sessions_without_events)
async def delete_session( def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
self._delete_session_impl( self._delete_session_impl(
@ -250,9 +250,16 @@ class InMemorySessionService(BaseSessionService):
self.sessions[app_name][user_id].pop(session_id) self.sessions[app_name][user_id].pop(session_id)
@override @override
async def append_event(self, session: Session, event: Event) -> Event: def append_event(self, session: Session, event: Event) -> Event:
return self._append_event_impl(session=session, event=event)
def append_event_sync(self, session: Session, event: Event) -> Event:
logger.warning('Deprecated. Please migrate to the async method.')
return self._append_event_impl(session=session, event=event)
def _append_event_impl(self, session: Session, event: Event) -> Event:
# Update the in-memory session. # Update the in-memory session.
await super().append_event(session=session, event=event) super().append_event(session=session, event=event)
session.last_update_time = event.timestamp session.last_update_time = event.timestamp
# Update the storage session # Update the storage session
@ -279,7 +286,7 @@ class InMemorySessionService(BaseSessionService):
] = event.actions.state_delta[key] ] = event.actions.state_delta[key]
storage_session = self.sessions[app_name][user_id].get(session_id) storage_session = self.sessions[app_name][user_id].get(session_id)
await super().append_event(session=storage_session, event=event) super().append_event(session=storage_session, event=event)
storage_session.last_update_time = event.timestamp storage_session.last_update_time = event.timestamp

View File

@ -48,7 +48,7 @@ class VertexAiSessionService(BaseSessionService):
self.api_client = client._api_client self.api_client = client._api_client
@override @override
async def create_session( def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService):
if state: if state:
session_json_dict['session_state'] = state session_json_dict['session_state'] = state
api_response = await self.api_client.async_request( api_response = self.api_client.request(
http_method='POST', http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions', path=f'reasoningEngines/{reasoning_engine_id}/sessions',
request_dict=session_json_dict, request_dict=session_json_dict,
@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt = 5 max_retry_attempt = 5
while max_retry_attempt >= 0: while max_retry_attempt >= 0:
lro_response = await self.api_client.async_request( lro_response = self.api_client.request(
http_method='GET', http_method='GET',
path=f'operations/{operation_id}', path=f'operations/{operation_id}',
request_dict={}, request_dict={},
@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt -= 1 max_retry_attempt -= 1
# Get session resource # Get session resource
get_session_api_response = await self.api_client.async_request( get_session_api_response = self.api_client.request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -112,7 +112,7 @@ class VertexAiSessionService(BaseSessionService):
return session return session
@override @override
async def get_session( def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService):
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
# Get session resource # Get session resource
get_session_api_response = await self.api_client.async_request( get_session_api_response = self.api_client.request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService):
last_update_time=update_timestamp, last_update_time=update_timestamp,
) )
list_events_api_response = await self.api_client.async_request( list_events_api_response = self.api_client.request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={}, request_dict={},
@ -175,7 +175,7 @@ class VertexAiSessionService(BaseSessionService):
return session return session
@override @override
async def list_sessions( def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
@ -202,23 +202,23 @@ class VertexAiSessionService(BaseSessionService):
sessions.append(session) sessions.append(session)
return ListSessionsResponse(sessions=sessions) return ListSessionsResponse(sessions=sessions)
async def delete_session( def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
await self.api_client.async_request( self.api_client.request(
http_method='DELETE', http_method='DELETE',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
) )
@override @override
async def append_event(self, session: Session, event: Event) -> Event: def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session. # Update the in-memory session.
await super().append_event(session=session, event=event) super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
await self.api_client.async_request( self.api_client.request(
http_method='POST', http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
request_dict=_convert_event_to_json(event), request_dict=_convert_event_to_json(event),

View File

@ -129,7 +129,7 @@ class AgentTool(BaseTool):
session_service=InMemorySessionService(), session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(), memory_service=InMemoryMemoryService(),
) )
session = await runner.session_service.create_session( session = runner.session_service.create_session(
app_name=self.agent.name, app_name=self.agent.name,
user_id='tmp_user', user_id='tmp_user',
state=tool_context.state.to_dict(), state=tool_context.state.to_dict(),

View File

@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent):
) )
async def _create_parent_invocation_context( def _create_parent_invocation_context(
test_name: str, agent: BaseAgent, branch: Optional[str] = None test_name: str, agent: BaseAgent, branch: Optional[str] = None
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -134,7 +134,7 @@ def test_invalid_agent_name():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async(request: pytest.FixtureRequest): async def test_run_async(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_with_branch(request: pytest.FixtureRequest): async def test_run_async_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch' request.function.__name__, agent, branch='parent_branch'
) )
@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_noop, before_agent_callback=_before_agent_callback_noop,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop, before_agent_callback=_async_before_agent_callback_noop,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_bypass_agent, before_agent_callback=_before_agent_callback_bypass_agent,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent, before_agent_callback=_async_before_agent_callback_bypass_agent,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs], before_agent_callback=[mock_cb for mock_cb in mock_cbs],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
result = [e async for e in agent.run_async(parent_ctx)] result = [e async for e in agent.run_async(parent_ctx)]
@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs], after_agent_callback=[mock_cb for mock_cb in mock_cbs],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
result = [e async for e in agent.run_async(parent_ctx)] result = [e async for e in agent.run_async(parent_ctx)]
@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_noop, after_agent_callback=_after_agent_callback_noop,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop, after_agent_callback=_async_after_agent_callback_noop,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_append_agent_reply, after_agent_callback=_after_agent_callback_append_agent_reply,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply, after_agent_callback=_async_after_agent_callback_append_agent_reply,
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_live(request: pytest.FixtureRequest): async def test_run_live(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_live_with_branch(request: pytest.FixtureRequest): async def test_run_live_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch' request.function.__name__, agent, branch='parent_branch'
) )
@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest): async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )

View File

@ -15,7 +15,7 @@
"""Unit tests for canonical_xxx fields in LlmAgent.""" """Unit tests for canonical_xxx fields in LlmAgent."""
from typing import Any from typing import Any
from typing import Optional, cast from typing import Optional
from google.adk.agents.callback_context import CallbackContext from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
@ -30,11 +30,11 @@ from pydantic import BaseModel
import pytest import pytest
async def _create_readonly_context( def _create_readonly_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None agent: LlmAgent, state: Optional[dict[str, Any]] = None
) -> ReadonlyContext: ) -> ReadonlyContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = session_service.create_session(
app_name='test_app', user_id='test_user', state=state app_name='test_app', user_id='test_user', state=state
) )
invocation_context = InvocationContext( invocation_context = InvocationContext(
@ -77,7 +77,7 @@ def test_canonical_model_inherit():
async def test_canonical_instruction_str(): async def test_canonical_instruction_str():
agent = LlmAgent(name='test_agent', instruction='instruction') agent = LlmAgent(name='test_agent', instruction='instruction')
ctx = await _create_readonly_context(agent) ctx = _create_readonly_context(agent)
canonical_instruction = await agent.canonical_instruction(ctx) canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction' assert canonical_instruction == 'instruction'
@ -88,9 +88,7 @@ async def test_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}' return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider) agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = await _create_readonly_context( ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx) canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value' assert canonical_instruction == 'instruction: state_value'
@ -101,9 +99,7 @@ async def test_async_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}' return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider) agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = await _create_readonly_context( ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx) canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value' assert canonical_instruction == 'instruction: state_value'
@ -111,10 +107,10 @@ async def test_async_canonical_instruction():
async def test_canonical_global_instruction_str(): async def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction') agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = await _create_readonly_context(agent) ctx = _create_readonly_context(agent)
canonical_instruction = await agent.canonical_global_instruction(ctx) canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_instruction == 'global instruction' assert canonical_global_instruction == 'global instruction'
async def test_canonical_global_instruction(): async def test_canonical_global_instruction():
@ -124,9 +120,7 @@ async def test_canonical_global_instruction():
agent = LlmAgent( agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider name='test_agent', global_instruction=_global_instruction_provider
) )
ctx = await _create_readonly_context( ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx) canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction: state_value' assert canonical_global_instruction == 'global instruction: state_value'
@ -139,14 +133,10 @@ async def test_async_canonical_global_instruction():
agent = LlmAgent( agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider name='test_agent', global_instruction=_global_instruction_provider
) )
ctx = await _create_readonly_context( ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
agent, state={'state_var': 'state_value'}
)
assert ( canonical_global_instruction = await agent.canonical_global_instruction(ctx)
await agent.canonical_global_instruction(ctx) assert canonical_global_instruction == 'global instruction: state_value'
== 'global instruction: state_value'
)
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):

View File

@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent):
) )
async def _create_parent_invocation_context( def _create_parent_invocation_context(
test_name: str, agent: BaseAgent test_name: str, agent: BaseAgent
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent, agent,
], ],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, loop_agent request.function.__name__, loop_agent
) )
events = [e async for e in loop_agent.run_async(parent_ctx)] events = [e async for e in loop_agent.run_async(parent_ctx)]
@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
name=f'{request.function.__name__}_test_loop_agent', name=f'{request.function.__name__}_test_loop_agent',
sub_agents=[non_escalating_agent, escalating_agent], sub_agents=[non_escalating_agent, escalating_agent],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, loop_agent request.function.__name__, loop_agent
) )
events = [e async for e in loop_agent.run_async(parent_ctx)] events = [e async for e in loop_agent.run_async(parent_ctx)]

View File

@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent):
) )
async def _create_parent_invocation_context( def _create_parent_invocation_context(
test_name: str, agent: BaseAgent test_name: str, agent: BaseAgent
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent2, agent2,
], ],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, parallel_agent request.function.__name__, parallel_agent
) )
events = [e async for e in parallel_agent.run_async(parent_ctx)] events = [e async for e in parallel_agent.run_async(parent_ctx)]

View File

@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent):
) )
async def _create_parent_invocation_context( def _create_parent_invocation_context(
test_name: str, agent: BaseAgent test_name: str, agent: BaseAgent
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent_2, agent_2,
], ],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, sequential_agent request.function.__name__, sequential_agent
) )
events = [e async for e in sequential_agent.run_async(parent_ctx)] events = [e async for e in sequential_agent.run_async(parent_ctx)]
@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest):
agent_2, agent_2,
], ],
) )
parent_ctx = await _create_parent_invocation_context( parent_ctx = _create_parent_invocation_context(
request.function.__name__, sequential_agent request.function.__name__, sequential_agent
) )
events = [e async for e in sequential_agent.run_live(parent_ctx)] events = [e async for e in sequential_agent.run_live(parent_ctx)]

View File

@ -195,7 +195,7 @@ async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch
"""run_interactively should skip blank input, echo once, then exit.""" """run_interactively should skip blank input, echo once, then exit."""
# make a session that belongs to dummy agent # make a session that belongs to dummy agent
svc = cli.InMemorySessionService() svc = cli.InMemorySessionService()
sess = await svc.create_session(app_name="dummy", user_id="u") sess = svc.create_session(app_name="dummy", user_id="u")
artifact_service = cli.InMemoryArtifactService() artifact_service = cli.InMemoryArtifactService()
root_agent = types.SimpleNamespace(name="root") root_agent = types.SimpleNamespace(name="root")
@ -211,3 +211,46 @@ async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch
# verify: assistant echoed once with 'echo:hello' # verify: assistant echoed once with 'echo:hello'
assert any("echo:hello" in m for m in echoed) assert any("echo:hello" in m for m in echoed)
# run_cli (resume branch)
@pytest.mark.asyncio
async def test_run_cli_resume_saved_session(tmp_path: Path, fake_agent, monkeypatch: pytest.MonkeyPatch) -> None:
"""run_cli should load previous session, print its events, then re-enter interactive mode."""
parent_dir, folder = fake_agent
# stub Session.model_validate_json to return dummy session with two events
user_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hi")])
assistant_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hello!")])
dummy_session = types.SimpleNamespace(
id="sess",
app_name=folder,
user_id="u",
events=[
types.SimpleNamespace(author="user", content=user_content, partial=False),
types.SimpleNamespace(author="assistant", content=assistant_content, partial=False),
],
)
monkeypatch.setattr(cli.Session, "model_validate_json", staticmethod(lambda _s: dummy_session))
monkeypatch.setattr(cli.InMemorySessionService, "append_event", lambda *_a, **_k: None)
# interactive inputs: immediately 'exit'
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "exit")
# collect echo output
captured: list[str] = []
monkeypatch.setattr(click, "echo", lambda m: captured.append(m))
saved_path = tmp_path / "prev.session.json"
saved_path.write_text("{}") # contents not used patched above
await cli.run_cli(
agent_parent_dir=str(parent_dir),
agent_folder_name=folder,
input_file=None,
saved_session_file=str(saved_path),
save_session=False,
)
# ④ ensure both historical messages were printed
assert any("[user]: hi" in m for m in captured)
assert any("[assistant]: hello!" in m for m in captured)

View File

@ -31,7 +31,7 @@ async def test_no_examples():
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[]) agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
invocation_context = await utils.create_invocation_context( invocation_context = utils.create_invocation_context(
agent=agent, user_content="" agent=agent, user_content=""
) )
@ -69,7 +69,7 @@ async def test_agent_examples():
name="agent", name="agent",
examples=example_list, examples=example_list,
) )
invocation_context = await utils.create_invocation_context( invocation_context = utils.create_invocation_context(
agent=agent, user_content="test" agent=agent, user_content="test"
) )
@ -122,7 +122,7 @@ async def test_agent_base_example_provider():
name="agent", name="agent",
examples=provider, examples=provider,
) )
invocation_context = await utils.create_invocation_context( invocation_context = utils.create_invocation_context(
agent=agent, user_content="test" agent=agent, user_content="test"
) )

View File

@ -81,7 +81,7 @@ async def invoke_tool_with_callbacks(
before_tool_callback=before_cb, before_tool_callback=before_cb,
after_tool_callback=after_cb, after_tool_callback=after_cb,
) )
invocation_context = await utils.create_invocation_context( invocation_context = utils.create_invocation_context(
agent=agent, user_content="" agent=agent, user_content=""
) )
# Build function call event # Build function call event

View File

@ -28,7 +28,7 @@ async def test_no_description():
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
agent = Agent(model="gemini-1.5-flash", name="agent") agent = Agent(model="gemini-1.5-flash", name="agent")
invocation_context = await utils.create_invocation_context(agent=agent) invocation_context = utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async( async for _ in identity.request_processor.run_async(
invocation_context, invocation_context,
@ -52,7 +52,7 @@ async def test_with_description():
name="agent", name="agent",
description="test description", description="test description",
) )
invocation_context = await utils.create_invocation_context(agent=agent) invocation_context = utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async( async for _ in identity.request_processor.run_async(
invocation_context, invocation_context,

View File

@ -36,7 +36,7 @@ async def test_build_system_instruction():
{{customer_int }, { non-identifier-float}}, \ {{customer_int }, { non-identifier-float}}, \
{'key1': 'value1'} and {{'key2': 'value2'}}."""), {'key1': 'value1'} and {{'key2': 'value2'}}."""),
) )
invocation_context = await utils.create_invocation_context(agent=agent) invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -73,7 +73,7 @@ async def test_function_system_instruction():
name="agent", name="agent",
instruction=build_function_instruction, instruction=build_function_instruction,
) )
invocation_context = await utils.create_invocation_context(agent=agent) invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -111,7 +111,7 @@ async def test_async_function_system_instruction():
name="agent", name="agent",
instruction=build_function_instruction, instruction=build_function_instruction,
) )
invocation_context = await utils.create_invocation_context(agent=agent) invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -147,7 +147,7 @@ async def test_global_system_instruction():
model="gemini-1.5-flash", model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
invocation_context = await utils.create_invocation_context(agent=sub_agent) invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -189,7 +189,7 @@ async def test_function_global_system_instruction():
model="gemini-1.5-flash", model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
invocation_context = await utils.create_invocation_context(agent=sub_agent) invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -231,7 +231,7 @@ async def test_async_function_global_system_instruction():
model="gemini-1.5-flash", model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
invocation_context = await utils.create_invocation_context(agent=sub_agent) invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -263,7 +263,7 @@ async def test_build_system_instruction_with_namespace():
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}.""" """Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
), ),
) )
invocation_context = await utils.create_invocation_context(agent=agent) invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",

View File

@ -37,28 +37,26 @@ def get_session_service(
return InMemorySessionService() return InMemorySessionService()
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_get_empty_session(service_type): def test_get_empty_session(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
assert not await session_service.get_session( assert not session_service.get_session(
app_name='my_app', user_id='test_user', session_id='123' app_name='my_app', user_id='test_user', session_id='123'
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_create_get_session(service_type): def test_create_get_session(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'test_user' user_id = 'test_user'
state = {'key': 'value'} state = {'key': 'value'}
session = await session_service.create_session( session = session_service.create_session(
app_name=app_name, user_id=user_id, state=state app_name=app_name, user_id=user_id, state=state
) )
assert session.app_name == app_name assert session.app_name == app_name
@ -66,53 +64,50 @@ async def test_create_get_session(service_type):
assert session.id assert session.id
assert session.state == state assert session.state == state
assert ( assert (
await session_service.get_session( session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
== session == session
) )
session_id = session.id session_id = session.id
await session_service.delete_session( session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
assert ( assert (
await session_service.get_session( not session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
!= session == session
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_create_and_list_sessions(service_type): def test_create_and_list_sessions(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'test_user' user_id = 'test_user'
session_ids = ['session' + str(i) for i in range(5)] session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids: for session_id in session_ids:
await session_service.create_session( session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
list_sessions_response = await session_service.list_sessions( sessions = session_service.list_sessions(
app_name=app_name, user_id=user_id app_name=app_name, user_id=user_id
) ).sessions
sessions = list_sessions_response.sessions
for i in range(len(sessions)): for i in range(len(sessions)):
assert sessions[i].id == session_ids[i] assert sessions[i].id == session_ids[i]
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_session_state(service_type): def test_session_state(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id_1 = 'user1' user_id_1 = 'user1'
@ -123,19 +118,19 @@ async def test_session_state(service_type):
state_11 = {'key11': 'value11'} state_11 = {'key11': 'value11'}
state_12 = {'key12': 'value12'} state_12 = {'key12': 'value12'}
session_11 = await session_service.create_session( session_11 = session_service.create_session(
app_name=app_name, app_name=app_name,
user_id=user_id_1, user_id=user_id_1,
state=state_11, state=state_11,
session_id=session_id_11, session_id=session_id_11,
) )
await session_service.create_session( session_service.create_session(
app_name=app_name, app_name=app_name,
user_id=user_id_1, user_id=user_id_1,
state=state_12, state=state_12,
session_id=session_id_12, session_id=session_id_12,
) )
await session_service.create_session( session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2 app_name=app_name, user_id=user_id_2, session_id=session_id_2
) )
@ -154,7 +149,7 @@ async def test_session_state(service_type):
} }
), ),
) )
await session_service.append_event(session=session_11, event=event) session_service.append_event(session=session_11, event=event)
# User and app state is stored, temp state is filtered. # User and app state is stored, temp state is filtered.
assert session_11.state.get('app:key') == 'value' assert session_11.state.get('app:key') == 'value'
@ -162,7 +157,7 @@ async def test_session_state(service_type):
assert session_11.state.get('user:key1') == 'value1' assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key') assert not session_11.state.get('temp:key')
session_12 = await session_service.get_session( session_12 = session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_12 app_name=app_name, user_id=user_id_1, session_id=session_id_12
) )
# After getting a new instance, the session_12 got the user and app state, # After getting a new instance, the session_12 got the user and app state,
@ -171,7 +166,7 @@ async def test_session_state(service_type):
assert not session_12.state.get('temp:key') assert not session_12.state.get('temp:key')
# The user1's state is not visible to user2, app state is visible # The user1's state is not visible to user2, app state is visible
session_2 = await session_service.get_session( session_2 = session_service.get_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2 app_name=app_name, user_id=user_id_2, session_id=session_id_2
) )
assert session_2.state.get('app:key') == 'value' assert session_2.state.get('app:key') == 'value'
@ -180,7 +175,7 @@ async def test_session_state(service_type):
assert not session_2.state.get('user:key1') assert not session_2.state.get('user:key1')
# The change to session_11 is persisted # The change to session_11 is persisted
session_11 = await session_service.get_session( session_11 = session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_11 app_name=app_name, user_id=user_id_1, session_id=session_id_11
) )
assert session_11.state.get('key11') == 'value11_new' assert session_11.state.get('key11') == 'value11_new'
@ -188,11 +183,10 @@ async def test_session_state(service_type):
assert not session_11.state.get('temp:key') assert not session_11.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_create_new_session_will_merge_states(service_type): def test_create_new_session_will_merge_states(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
@ -200,7 +194,7 @@ async def test_create_new_session_will_merge_states(service_type):
session_id_2 = 'session2' session_id_2 = 'session2'
state_1 = {'key1': 'value1'} state_1 = {'key1': 'value1'}
session_1 = await session_service.create_session( session_1 = session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1 app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
) )
@ -216,7 +210,7 @@ async def test_create_new_session_will_merge_states(service_type):
} }
), ),
) )
await session_service.append_event(session=session_1, event=event) session_service.append_event(session=session_1, event=event)
# User and app state is stored, temp state is filtered. # User and app state is stored, temp state is filtered.
assert session_1.state.get('app:key') == 'value' assert session_1.state.get('app:key') == 'value'
@ -224,7 +218,7 @@ async def test_create_new_session_will_merge_states(service_type):
assert session_1.state.get('user:key1') == 'value1' assert session_1.state.get('user:key1') == 'value1'
assert not session_1.state.get('temp:key') assert not session_1.state.get('temp:key')
session_2 = await session_service.create_session( session_2 = session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2 app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
) )
# Session 2 has the persisted states # Session 2 has the persisted states
@ -234,18 +228,15 @@ async def test_create_new_session_will_merge_states(service_type):
assert not session_2.state.get('temp:key') assert not session_2.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_append_event_bytes(service_type): def test_append_event_bytes(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
session = await session_service.create_session( session = session_service.create_session(app_name=app_name, user_id=user_id)
app_name=app_name, user_id=user_id
)
event = Event( event = Event(
invocation_id='invocation', invocation_id='invocation',
author='user', author='user',
@ -258,34 +249,30 @@ async def test_append_event_bytes(service_type):
], ],
), ),
) )
await session_service.append_event(session=session, event=event) session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes( assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png' data=b'test_image_data', mime_type='image/png'
) )
session = await session_service.get_session( events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) ).events
events = session.events
assert len(events) == 1 assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes( assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png' data=b'test_image_data', mime_type='image/png'
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
async def test_append_event_complete(service_type): def test_append_event_complete(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
session = await session_service.create_session( session = session_service.create_session(app_name=app_name, user_id=user_id)
app_name=app_name, user_id=user_id
)
event = Event( event = Event(
invocation_id='invocation', invocation_id='invocation',
author='user', author='user',
@ -304,73 +291,65 @@ async def test_append_event_complete(service_type):
error_message='error_message', error_message='error_message',
interrupted=True, interrupted=True,
) )
await session_service.append_event(session=session, event=event) session_service.append_event(session=session, event=event)
assert ( assert (
await session_service.get_session( session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
== session == session
) )
@pytest.mark.asyncio
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY]) @pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
async def test_get_session_with_config(service_type): def test_get_session_with_config(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
num_test_events = 5 num_test_events = 5
session = await session_service.create_session( session = session_service.create_session(app_name=app_name, user_id=user_id)
app_name=app_name, user_id=user_id
)
for i in range(1, num_test_events + 1): for i in range(1, num_test_events + 1):
event = Event(author='user', timestamp=i) event = Event(author='user', timestamp=i)
await session_service.append_event(session, event) session_service.append_event(session, event)
# No config, expect all events to be returned. # No config, expect all events to be returned.
session = await session_service.get_session( events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) ).events
events = session.events
assert len(events) == num_test_events assert len(events) == num_test_events
# Only expect the most recent 3 events. # Only expect the most recent 3 events.
num_recent_events = 3 num_recent_events = 3
config = GetSessionConfig(num_recent_events=num_recent_events) config = GetSessionConfig(num_recent_events=num_recent_events)
session = await session_service.get_session( events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
) ).events
events = session.events
assert len(events) == num_recent_events assert len(events) == num_recent_events
assert events[0].timestamp == num_test_events - num_recent_events + 1 assert events[0].timestamp == num_test_events - num_recent_events + 1
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events. # Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
after_timestamp = 4.0 after_timestamp = 4.0
config = GetSessionConfig(after_timestamp=after_timestamp) config = GetSessionConfig(after_timestamp=after_timestamp)
session = await session_service.get_session( events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
) ).events
events = session.events
assert len(events) == num_test_events - after_timestamp + 1 assert len(events) == num_test_events - after_timestamp + 1
assert events[0].timestamp == after_timestamp assert events[0].timestamp == after_timestamp
# Expect no events if none are > after_timestamp. # Expect no events if none are > after_timestamp.
way_after_timestamp = num_test_events * 10 way_after_timestamp = num_test_events * 10
config = GetSessionConfig(after_timestamp=way_after_timestamp) config = GetSessionConfig(after_timestamp=way_after_timestamp)
session = await session_service.get_session( events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
) ).events
assert not session.events assert len(events) == 0
# Both filters applied, i.e., of 3 most recent events, only 2 are after # Both filters applied, i.e., of 3 most recent events, only 2 are after
# timestamp 4.0, so expect 2 events. # timestamp 4.0, so expect 2 events.
config = GetSessionConfig( config = GetSessionConfig(
after_timestamp=after_timestamp, num_recent_events=num_recent_events after_timestamp=after_timestamp, num_recent_events=num_recent_events
) )
session = await session_service.get_session( events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
) ).events
events = session.events
assert len(events) == num_test_events - after_timestamp + 1 assert len(events) == num_test_events - after_timestamp + 1

View File

@ -15,7 +15,7 @@
import re import re
import this import this
from typing import Any from typing import Any
import uuid
from dateutil.parser import isoparse from dateutil.parser import isoparse
from google.adk.events import Event from google.adk.events import Event
from google.adk.events import EventActions from google.adk.events import EventActions
@ -124,9 +124,7 @@ class MockApiClient:
this.session_dict: dict[str, Any] = {} this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {} this.event_dict: dict[str, list[Any]] = {}
async def async_request( def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method.""" """Mocks the API Client request method."""
if http_method == 'GET': if http_method == 'GET':
if re.match(SESSION_REGEX, path): if re.match(SESSION_REGEX, path):
@ -212,52 +210,46 @@ def mock_vertex_ai_session_service():
return service return service
@pytest.mark.asyncio def test_get_empty_session():
async def test_get_empty_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
assert await session_service.get_session( assert session_service.get_session(
app_name='123', user_id='user', session_id='0' app_name='123', user_id='user', session_id='0'
) )
assert str(excinfo.value) == 'Session not found: 0' assert str(excinfo.value) == 'Session not found: 0'
@pytest.mark.asyncio def test_get_and_delete_session():
async def test_get_and_delete_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
assert ( assert (
await session_service.get_session( session_service.get_session(
app_name='123', user_id='user', session_id='1' app_name='123', user_id='user', session_id='1'
) )
== MOCK_SESSION == MOCK_SESSION
) )
await session_service.delete_session( session_service.delete_session(app_name='123', user_id='user', session_id='1')
app_name='123', user_id='user', session_id='1'
)
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
assert await session_service.get_session( assert session_service.get_session(
app_name='123', user_id='user', session_id='1' app_name='123', user_id='user', session_id='1'
) )
assert str(excinfo.value) == 'Session not found: 1' assert str(excinfo.value) == 'Session not found: 1'
@pytest.mark.asyncio def test_list_sessions():
async def test_list_sessions():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
sessions = await session_service.list_sessions(app_name='123', user_id='user') sessions = session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2 assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1' assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2' assert sessions.sessions[1].id == '2'
@pytest.mark.asyncio def test_create_session():
async def test_create_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
state = {'key': 'value'} state = {'key': 'value'}
session = await session_service.create_session( session = session_service.create_session(
app_name='123', user_id='user', state=state app_name='123', user_id='user', state=state
) )
assert session.state == state assert session.state == state
@ -266,17 +258,16 @@ async def test_create_session():
assert session.last_update_time is not None assert session.last_update_time is not None
session_id = session.id session_id = session.id
assert session == await session_service.get_session( assert session == session_service.get_session(
app_name='123', user_id='user', session_id=session_id app_name='123', user_id='user', session_id=session_id
) )
@pytest.mark.asyncio def test_create_session_with_custom_session_id():
async def test_create_session_with_custom_session_id():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
await session_service.create_session( session_service.create_session(
app_name='123', user_id='user', session_id='1' app_name='123', user_id='user', session_id='1'
) )
assert str(excinfo.value) == ( assert str(excinfo.value) == (

View File

@ -37,9 +37,9 @@ class _TestingTool(BaseTool):
return self.declaration return self.declaration
async def _create_tool_context() -> ToolContext: def _create_tool_context() -> ToolContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
agent = SequentialAgent(name='test_agent') agent = SequentialAgent(name='test_agent')
@ -55,7 +55,7 @@ async def _create_tool_context() -> ToolContext:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_llm_request_no_declaration(): async def test_process_llm_request_no_declaration():
tool = _TestingTool() tool = _TestingTool()
tool_context = await _create_tool_context() tool_context = _create_tool_context()
llm_request = LlmRequest() llm_request = LlmRequest()
await tool.process_llm_request( await tool.process_llm_request(
@ -77,7 +77,7 @@ async def test_process_llm_request_with_declaration():
) )
tool = _TestingTool(declaration) tool = _TestingTool(declaration)
llm_request = LlmRequest() llm_request = LlmRequest()
tool_context = await _create_tool_context() tool_context = _create_tool_context()
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request
@ -102,7 +102,7 @@ async def test_process_llm_request_with_builtin_tool():
tools=[types.Tool(google_search=types.GoogleSearch())] tools=[types.Tool(google_search=types.GoogleSearch())]
) )
) )
tool_context = await _create_tool_context() tool_context = _create_tool_context()
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request
@ -131,7 +131,7 @@ async def test_process_llm_request_with_builtin_tool_and_another_declaration():
] ]
) )
) )
tool_context = await _create_tool_context() tool_context = _create_tool_context()
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request

View File

@ -56,7 +56,7 @@ class ModelContent(types.Content):
super().__init__(role='model', parts=parts) super().__init__(role='model', parts=parts)
async def create_invocation_context(agent: Agent, user_content: str = ''): def create_invocation_context(agent: Agent, user_content: str = ''):
invocation_id = 'test_id' invocation_id = 'test_id'
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService() session_service = InMemorySessionService()
@ -67,7 +67,7 @@ async def create_invocation_context(agent: Agent, user_content: str = ''):
memory_service=memory_service, memory_service=memory_service,
invocation_id=invocation_id, invocation_id=invocation_id,
agent=agent, agent=agent,
session=await session_service.create_session( session=session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
), ),
user_content=types.Content( user_content=types.Content(
@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner):
self, new_message: types.ContentUnion self, new_message: types.ContentUnion
) -> list[Event]: ) -> list[Event]:
session = await self.session_service.create_session( session = self.session_service.create_session(
app_name='InMemoryRunner', user_id='test_user' app_name='InMemoryRunner', user_id='test_user'
) )
collected_events = [] collected_events = []
@ -172,23 +172,15 @@ class InMemoryRunner:
session_service=InMemorySessionService(), session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(), memory_service=InMemoryMemoryService(),
) )
self.session_id = None self.session_id = self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
).id
@property @property
def session(self) -> Session: def session(self) -> Session:
if not self.session_id: return self.runner.session_service.get_session(
session = asyncio.run(
self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
)
)
self.session_id = session.id
return session
return asyncio.run(
self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id app_name='test_app', user_id='test_user', session_id=self.session_id
) )
)
def run(self, new_message: types.ContentUnion) -> list[Event]: def run(self, new_message: types.ContentUnion) -> list[Event]:
return list( return list(
@ -202,9 +194,9 @@ class InMemoryRunner:
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
collected_responses = [] collected_responses = []
async def consume_responses(session: Session): async def consume_responses():
run_res = self.runner.run_live( run_res = self.runner.run_live(
session=session, session=self.session,
live_request_queue=live_request_queue, live_request_queue=live_request_queue,
) )
@ -215,8 +207,7 @@ class InMemoryRunner:
return return
try: try:
session = self.session asyncio.run(consume_responses())
asyncio.run(consume_responses(session))
except asyncio.TimeoutError: except asyncio.TimeoutError:
print('Returning any partial results collected so far.') print('Returning any partial results collected so far.')