ADK changes

PiperOrigin-RevId: 759259620
This commit is contained in:
Google Team Member 2025-05-15 12:46:12 -07:00 committed by Copybara-Service
parent 1804ca39a6
commit 05917cabbd
23 changed files with 264 additions and 268 deletions

View File

@ -55,7 +55,7 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read()) input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now() input_file.state['_time'] = datetime.now()
session = session_service.create_session( session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state app_name=app_name, user_id=user_id, state=input_file.state
) )
for query in input_file.queries: for query in input_file.queries:
@ -130,7 +130,7 @@ async def run_cli(
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name) agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user' user_id = 'test_user'
session = session_service.create_session( session = await session_service.create_session(
app_name=agent_folder_name, user_id=user_id app_name=agent_folder_name, user_id=user_id
) )
root_agent = agent_module.agent.root_agent root_agent = agent_module.agent.root_agent
@ -145,14 +145,12 @@ async def run_cli(
input_path=input_file, input_path=input_file,
) )
elif saved_session_file: elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f: with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read()) loaded_session = Session.model_validate_json(f.read())
if loaded_session: if loaded_session:
for event in loaded_session.events: for event in loaded_session.events:
session_service.append_event(session, event) await session_service.append_event(session, event)
content = event.content content = event.content
if not content or not content.parts or not content.parts[0].text: if not content or not content.parts or not content.parts[0].text:
continue continue
@ -181,7 +179,7 @@ async def run_cli(
session_path = f'{agent_module_path}/{session_id}.session.json' session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details. # Fetch the session again to get all the details.
session = session_service.get_session( session = await session_service.get_session(
app_name=session.app_name, app_name=session.app_name,
user_id=session.user_id, user_id=session.user_id,
session_id=session.id, session_id=session.id,

View File

@ -333,10 +333,12 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def get_session(app_name: str, user_id: str, session_id: str) -> Session: async def get_session(
app_name: str, user_id: str, session_id: str
) -> Session:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
if not session: if not session:
@ -347,14 +349,15 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions", "/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def list_sessions(app_name: str, user_id: str) -> list[Session]: async def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
return [ return [
session session
for session in session_service.list_sessions( for session in list_sessions_response.sessions
app_name=app_name, user_id=user_id
).sessions
# Remove sessions that were generated as a part of Eval. # Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX) if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
] ]
@ -363,7 +366,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def create_session_with_id( async def create_session_with_id(
app_name: str, app_name: str,
user_id: str, user_id: str,
session_id: str, session_id: str,
@ -372,7 +375,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
if ( if (
session_service.get_session( await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
is not None is not None
@ -382,7 +385,7 @@ def get_fast_api_app(
status_code=400, detail=f"Session already exists: {session_id}" status_code=400, detail=f"Session already exists: {session_id}"
) )
logger.info("New session created: %s", session_id) logger.info("New session created: %s", session_id)
return session_service.create_session( return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id app_name=app_name, user_id=user_id, state=state, session_id=session_id
) )
@ -390,7 +393,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions", "/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def create_session( async def create_session(
app_name: str, app_name: str,
user_id: str, user_id: str,
state: Optional[dict[str, Any]] = None, state: Optional[dict[str, Any]] = None,
@ -398,7 +401,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created") logger.info("New session created")
return session_service.create_session( return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state app_name=app_name, user_id=user_id, state=state
) )
@ -442,7 +445,7 @@ def get_fast_api_app(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
): ):
# Get the session # Get the session
session = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id app_name=app_name, user_id=req.user_id, session_id=req.session_id
) )
assert session, "Session not found." assert session, "Session not found."
@ -530,7 +533,7 @@ def get_fast_api_app(
session_id=eval_case_result.session_id, session_id=eval_case_result.session_id,
) )
) )
eval_case_result.session_details = session_service.get_session( eval_case_result.session_details = await session_service.get_session(
app_name=app_name, app_name=app_name,
user_id=eval_case_result.user_id, user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id, session_id=eval_case_result.session_id,
@ -615,10 +618,10 @@ def get_fast_api_app(
return eval_result_files return eval_result_files
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
def delete_session(app_name: str, user_id: str, session_id: str): async def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
session_service.delete_session( await session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
@ -713,7 +716,7 @@ def get_fast_api_app(
async def agent_run(req: AgentRunRequest) -> list[Event]: async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name app_id = agent_engine_id if agent_engine_id else req.app_name
session = session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id app_name=app_id, user_id=req.user_id, session_id=req.session_id
) )
if not session: if not session:
@ -735,7 +738,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint # SSE endpoint
session = session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id app_name=app_id, user_id=req.user_id, session_id=req.session_id
) )
if not session: if not session:
@ -776,7 +779,7 @@ def get_fast_api_app(
): ):
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id app_name=app_id, user_id=user_id, session_id=session_id
) )
session_events = session.events if session else [] session_events = session.events if session else []
@ -833,7 +836,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id app_name=app_id, user_id=user_id, session_id=session_id
) )
if not session: if not session:

View File

@ -126,7 +126,7 @@ class EvaluationGenerator:
user_id = initial_session.user_id if initial_session else "test_user_id" user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4()) session_id = session_id if session_id else str(uuid.uuid4())
_ = session_service.create_session( _ = await session_service.create_session(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
state=initial_session.state if initial_session else {}, state=initial_session.state if initial_session else {},

View File

@ -173,7 +173,7 @@ class Runner:
The events generated by the agent. The events generated by the agent.
""" """
with tracer.start_as_current_span('invocation'): with tracer.start_as_current_span('invocation'):
session = self.session_service.get_session( session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id app_name=self.app_name, user_id=user_id, session_id=session_id
) )
if not session: if not session:
@ -197,7 +197,7 @@ class Runner:
invocation_context.agent = self._find_agent_to_run(session, root_agent) invocation_context.agent = self._find_agent_to_run(session, root_agent)
async for event in invocation_context.agent.run_async(invocation_context): async for event in invocation_context.agent.run_async(invocation_context):
if not event.partial: if not event.partial:
self.session_service.append_event(session=session, event=event) await self.session_service.append_event(session=session, event=event)
yield event yield event
async def _append_new_message_to_session( async def _append_new_message_to_session(
@ -242,7 +242,7 @@ class Runner:
author='user', author='user',
content=new_message, content=new_message,
) )
self.session_service.append_event(session=session, event=event) await self.session_service.append_event(session=session, event=event)
async def run_live( async def run_live(
self, self,
@ -324,7 +324,7 @@ class Runner:
) )
async for event in invocation_context.agent.run_live(invocation_context): async for event in invocation_context.agent.run_live(invocation_context):
self.session_service.append_event(session=session, event=event) await self.session_service.append_event(session=session, event=event)
yield event yield event
async def close_session(self, session: Session): async def close_session(self, session: Session):
@ -335,7 +335,7 @@ class Runner:
""" """
if self.memory_service: if self.memory_service:
await self.memory_service.add_session_to_memory(session) await self.memory_service.add_session_to_memory(session)
self.session_service.close_session(session=session) await self.session_service.close_session(session=session)
def _find_agent_to_run( def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent self, session: Session, root_agent: BaseAgent

View File

@ -47,7 +47,7 @@ class BaseSessionService(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def create_session( async def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -67,10 +67,9 @@ class BaseSessionService(abc.ABC):
Returns: Returns:
session: The newly created session instance. session: The newly created session instance.
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def get_session( async def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -79,28 +78,24 @@ class BaseSessionService(abc.ABC):
config: Optional[GetSessionConfig] = None, config: Optional[GetSessionConfig] = None,
) -> Optional[Session]: ) -> Optional[Session]:
"""Gets a session.""" """Gets a session."""
pass
@abc.abstractmethod @abc.abstractmethod
def list_sessions( async def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
"""Lists all the sessions.""" """Lists all the sessions."""
pass
@abc.abstractmethod @abc.abstractmethod
def delete_session( async def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
"""Deletes a session.""" """Deletes a session."""
pass
def close_session(self, *, session: Session): async def close_session(self, *, session: Session):
"""Closes a session.""" """Closes a session."""
# TODO: determine whether we want to finalize the session here. # TODO: determine whether we want to finalize the session here.
pass
def append_event(self, session: Session, event: Event) -> Event: async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object.""" """Appends an event to a session object."""
if event.partial: if event.partial:
return event return event

View File

@ -283,7 +283,7 @@ class DatabaseSessionService(BaseSessionService):
Base.metadata.create_all(self.db_engine) Base.metadata.create_all(self.db_engine)
@override @override
def create_session( async def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -357,7 +357,7 @@ class DatabaseSessionService(BaseSessionService):
return session return session
@override @override
def get_session( async def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -431,7 +431,7 @@ class DatabaseSessionService(BaseSessionService):
return session return session
@override @override
def list_sessions( async def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
with self.DatabaseSessionFactory() as sessionFactory: with self.DatabaseSessionFactory() as sessionFactory:
@ -454,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
return ListSessionsResponse(sessions=sessions) return ListSessionsResponse(sessions=sessions)
@override @override
def delete_session( async def delete_session(
self, app_name: str, user_id: str, session_id: str self, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
with self.DatabaseSessionFactory() as sessionFactory: with self.DatabaseSessionFactory() as sessionFactory:
@ -467,7 +467,7 @@ class DatabaseSessionService(BaseSessionService):
sessionFactory.commit() sessionFactory.commit()
@override @override
def append_event(self, session: Session, event: Event) -> Event: async def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}") logger.info(f"Append event: {event} to session {session.id}")
if event.partial: if event.partial:
@ -552,9 +552,10 @@ class DatabaseSessionService(BaseSessionService):
session.last_update_time = storage_session.update_time.timestamp() session.last_update_time = storage_session.update_time.timestamp()
# Also update the in-memory session # Also update the in-memory session
super().append_event(session=session, event=event) await super().append_event(session=session, event=event)
return event return event
def convert_event(event: StorageEvent) -> Event: def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event.""" """Converts a storage event to an event."""
return Event( return Event(

View File

@ -44,7 +44,7 @@ class InMemorySessionService(BaseSessionService):
self.app_state: dict[str, dict[str, Any]] = {} self.app_state: dict[str, dict[str, Any]] = {}
@override @override
def create_session( async def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -106,7 +106,7 @@ class InMemorySessionService(BaseSessionService):
return self._merge_state(app_name, user_id, copied_session) return self._merge_state(app_name, user_id, copied_session)
@override @override
def get_session( async def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -193,7 +193,7 @@ class InMemorySessionService(BaseSessionService):
return copied_session return copied_session
@override @override
def list_sessions( async def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id) return self._list_sessions_impl(app_name=app_name, user_id=user_id)
@ -221,7 +221,7 @@ class InMemorySessionService(BaseSessionService):
sessions_without_events.append(copied_session) sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events) return ListSessionsResponse(sessions=sessions_without_events)
def delete_session( async def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
self._delete_session_impl( self._delete_session_impl(
@ -250,16 +250,9 @@ class InMemorySessionService(BaseSessionService):
self.sessions[app_name][user_id].pop(session_id) self.sessions[app_name][user_id].pop(session_id)
@override @override
def append_event(self, session: Session, event: Event) -> Event: async def append_event(self, session: Session, event: Event) -> Event:
return self._append_event_impl(session=session, event=event)
def append_event_sync(self, session: Session, event: Event) -> Event:
logger.warning('Deprecated. Please migrate to the async method.')
return self._append_event_impl(session=session, event=event)
def _append_event_impl(self, session: Session, event: Event) -> Event:
# Update the in-memory session. # Update the in-memory session.
super().append_event(session=session, event=event) await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp session.last_update_time = event.timestamp
# Update the storage session # Update the storage session
@ -286,7 +279,7 @@ class InMemorySessionService(BaseSessionService):
] = event.actions.state_delta[key] ] = event.actions.state_delta[key]
storage_session = self.sessions[app_name][user_id].get(session_id) storage_session = self.sessions[app_name][user_id].get(session_id)
super().append_event(session=storage_session, event=event) await super().append_event(session=storage_session, event=event)
storage_session.last_update_time = event.timestamp storage_session.last_update_time = event.timestamp

View File

@ -48,7 +48,7 @@ class VertexAiSessionService(BaseSessionService):
self.api_client = client._api_client self.api_client = client._api_client
@override @override
def create_session( async def create_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService):
if state: if state:
session_json_dict['session_state'] = state session_json_dict['session_state'] = state
api_response = self.api_client.request( api_response = await self.api_client.async_request(
http_method='POST', http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions', path=f'reasoningEngines/{reasoning_engine_id}/sessions',
request_dict=session_json_dict, request_dict=session_json_dict,
@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt = 5 max_retry_attempt = 5
while max_retry_attempt >= 0: while max_retry_attempt >= 0:
lro_response = self.api_client.request( lro_response = await self.api_client.async_request(
http_method='GET', http_method='GET',
path=f'operations/{operation_id}', path=f'operations/{operation_id}',
request_dict={}, request_dict={},
@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt -= 1 max_retry_attempt -= 1
# Get session resource # Get session resource
get_session_api_response = self.api_client.request( get_session_api_response = await self.api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -112,7 +112,7 @@ class VertexAiSessionService(BaseSessionService):
return session return session
@override @override
def get_session( async def get_session(
self, self,
*, *,
app_name: str, app_name: str,
@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService):
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
# Get session resource # Get session resource
get_session_api_response = self.api_client.request( get_session_api_response = await self.api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService):
last_update_time=update_timestamp, last_update_time=update_timestamp,
) )
list_events_api_response = self.api_client.request( list_events_api_response = await self.api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={}, request_dict={},
@ -175,7 +175,7 @@ class VertexAiSessionService(BaseSessionService):
return session return session
@override @override
def list_sessions( async def list_sessions(
self, *, app_name: str, user_id: str self, *, app_name: str, user_id: str
) -> ListSessionsResponse: ) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
@ -202,23 +202,23 @@ class VertexAiSessionService(BaseSessionService):
sessions.append(session) sessions.append(session)
return ListSessionsResponse(sessions=sessions) return ListSessionsResponse(sessions=sessions)
def delete_session( async def delete_session(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
self.api_client.request( await self.api_client.async_request(
http_method='DELETE', http_method='DELETE',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
) )
@override @override
def append_event(self, session: Session, event: Event) -> Event: async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session. # Update the in-memory session.
super().append_event(session=session, event=event) await super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
self.api_client.request( await self.api_client.async_request(
http_method='POST', http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
request_dict=_convert_event_to_json(event), request_dict=_convert_event_to_json(event),

View File

@ -129,7 +129,7 @@ class AgentTool(BaseTool):
session_service=InMemorySessionService(), session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(), memory_service=InMemoryMemoryService(),
) )
session = runner.session_service.create_session( session = await runner.session_service.create_session(
app_name=self.agent.name, app_name=self.agent.name,
user_id='tmp_user', user_id='tmp_user',
state=tool_context.state.to_dict(), state=tool_context.state.to_dict(),

View File

@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent):
) )
def _create_parent_invocation_context( async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent, branch: Optional[str] = None test_name: str, agent: BaseAgent, branch: Optional[str] = None
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -134,7 +134,7 @@ def test_invalid_agent_name():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async(request: pytest.FixtureRequest): async def test_run_async(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_with_branch(request: pytest.FixtureRequest): async def test_run_async_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch' request.function.__name__, agent, branch='parent_branch'
) )
@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_noop, before_agent_callback=_before_agent_callback_noop,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop, before_agent_callback=_async_before_agent_callback_noop,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_bypass_agent, before_agent_callback=_before_agent_callback_bypass_agent,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent, before_agent_callback=_async_before_agent_callback_bypass_agent,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs], before_agent_callback=[mock_cb for mock_cb in mock_cbs],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
result = [e async for e in agent.run_async(parent_ctx)] result = [e async for e in agent.run_async(parent_ctx)]
@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs], after_agent_callback=[mock_cb for mock_cb in mock_cbs],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
result = [e async for e in agent.run_async(parent_ctx)] result = [e async for e in agent.run_async(parent_ctx)]
@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_noop, after_agent_callback=_after_agent_callback_noop,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop, after_agent_callback=_async_after_agent_callback_noop,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_append_agent_reply, after_agent_callback=_after_agent_callback_append_agent_reply,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent', name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply, after_agent_callback=_async_after_agent_callback_append_agent_reply,
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_live(request: pytest.FixtureRequest): async def test_run_live(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )
@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_live_with_branch(request: pytest.FixtureRequest): async def test_run_live_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch' request.function.__name__, agent, branch='parent_branch'
) )
@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest): async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent request.function.__name__, agent
) )

View File

@ -15,7 +15,7 @@
"""Unit tests for canonical_xxx fields in LlmAgent.""" """Unit tests for canonical_xxx fields in LlmAgent."""
from typing import Any from typing import Any
from typing import Optional from typing import Optional, cast
from google.adk.agents.callback_context import CallbackContext from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
@ -30,11 +30,11 @@ from pydantic import BaseModel
import pytest import pytest
def _create_readonly_context( async def _create_readonly_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None agent: LlmAgent, state: Optional[dict[str, Any]] = None
) -> ReadonlyContext: ) -> ReadonlyContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user', state=state app_name='test_app', user_id='test_user', state=state
) )
invocation_context = InvocationContext( invocation_context = InvocationContext(
@ -77,7 +77,7 @@ def test_canonical_model_inherit():
async def test_canonical_instruction_str(): async def test_canonical_instruction_str():
agent = LlmAgent(name='test_agent', instruction='instruction') agent = LlmAgent(name='test_agent', instruction='instruction')
ctx = _create_readonly_context(agent) ctx = await _create_readonly_context(agent)
canonical_instruction = await agent.canonical_instruction(ctx) canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction' assert canonical_instruction == 'instruction'
@ -88,7 +88,9 @@ async def test_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}' return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider) agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx) canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value' assert canonical_instruction == 'instruction: state_value'
@ -99,7 +101,9 @@ async def test_async_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}' return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider) agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx) canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value' assert canonical_instruction == 'instruction: state_value'
@ -107,10 +111,10 @@ async def test_async_canonical_instruction():
async def test_canonical_global_instruction_str(): async def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction') agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = _create_readonly_context(agent) ctx = await _create_readonly_context(agent)
canonical_global_instruction = await agent.canonical_global_instruction(ctx) canonical_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction' assert canonical_instruction == 'global instruction'
async def test_canonical_global_instruction(): async def test_canonical_global_instruction():
@ -120,7 +124,9 @@ async def test_canonical_global_instruction():
agent = LlmAgent( agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider name='test_agent', global_instruction=_global_instruction_provider
) )
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx) canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction: state_value' assert canonical_global_instruction == 'global instruction: state_value'
@ -133,10 +139,14 @@ async def test_async_canonical_global_instruction():
agent = LlmAgent( agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider name='test_agent', global_instruction=_global_instruction_provider
) )
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx) assert (
assert canonical_global_instruction == 'global instruction: state_value' await agent.canonical_global_instruction(ctx)
== 'global instruction: state_value'
)
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):

View File

@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent):
) )
def _create_parent_invocation_context( async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent test_name: str, agent: BaseAgent
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent, agent,
], ],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, loop_agent request.function.__name__, loop_agent
) )
events = [e async for e in loop_agent.run_async(parent_ctx)] events = [e async for e in loop_agent.run_async(parent_ctx)]
@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
name=f'{request.function.__name__}_test_loop_agent', name=f'{request.function.__name__}_test_loop_agent',
sub_agents=[non_escalating_agent, escalating_agent], sub_agents=[non_escalating_agent, escalating_agent],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, loop_agent request.function.__name__, loop_agent
) )
events = [e async for e in loop_agent.run_async(parent_ctx)] events = [e async for e in loop_agent.run_async(parent_ctx)]

View File

@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent):
) )
def _create_parent_invocation_context( async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent test_name: str, agent: BaseAgent
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent2, agent2,
], ],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent request.function.__name__, parallel_agent
) )
events = [e async for e in parallel_agent.run_async(parent_ctx)] events = [e async for e in parallel_agent.run_async(parent_ctx)]

View File

@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent):
) )
def _create_parent_invocation_context( async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent test_name: str, agent: BaseAgent
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
return InvocationContext( return InvocationContext(
@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent_2, agent_2,
], ],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent request.function.__name__, sequential_agent
) )
events = [e async for e in sequential_agent.run_async(parent_ctx)] events = [e async for e in sequential_agent.run_async(parent_ctx)]
@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest):
agent_2, agent_2,
], ],
) )
parent_ctx = _create_parent_invocation_context( parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent request.function.__name__, sequential_agent
) )
events = [e async for e in sequential_agent.run_live(parent_ctx)] events = [e async for e in sequential_agent.run_live(parent_ctx)]

View File

@ -192,65 +192,22 @@ async def test_run_cli_save_session(fake_agent, tmp_path: Path, monkeypatch: pyt
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""run_interactively should skip blank input, echo once, then exit.""" """run_interactively should skip blank input, echo once, then exit."""
# make a session that belongs to dummy agent # make a session that belongs to dummy agent
svc = cli.InMemorySessionService() svc = cli.InMemorySessionService()
sess = svc.create_session(app_name="dummy", user_id="u") sess = await svc.create_session(app_name="dummy", user_id="u")
artifact_service = cli.InMemoryArtifactService() artifact_service = cli.InMemoryArtifactService()
root_agent = types.SimpleNamespace(name="root") root_agent = types.SimpleNamespace(name="root")
# fake user input: blank -> 'hello' -> 'exit' # fake user input: blank -> 'hello' -> 'exit'
answers = iter([" ", "hello", "exit"]) answers = iter([" ", "hello", "exit"])
monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers)) monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers))
# capture assisted echo # capture assisted echo
echoed: list[str] = [] echoed: list[str] = []
monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg))
await cli.run_interactively(root_agent, artifact_service, sess, svc) await cli.run_interactively(root_agent, artifact_service, sess, svc)
# verify: assistant echoed once with 'echo:hello' # verify: assistant echoed once with 'echo:hello'
assert any("echo:hello" in m for m in echoed) assert any("echo:hello" in m for m in echoed)
# run_cli (resume branch)
@pytest.mark.asyncio
async def test_run_cli_resume_saved_session(tmp_path: Path, fake_agent, monkeypatch: pytest.MonkeyPatch) -> None:
"""run_cli should load previous session, print its events, then re-enter interactive mode."""
parent_dir, folder = fake_agent
# stub Session.model_validate_json to return dummy session with two events
user_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hi")])
assistant_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hello!")])
dummy_session = types.SimpleNamespace(
id="sess",
app_name=folder,
user_id="u",
events=[
types.SimpleNamespace(author="user", content=user_content, partial=False),
types.SimpleNamespace(author="assistant", content=assistant_content, partial=False),
],
)
monkeypatch.setattr(cli.Session, "model_validate_json", staticmethod(lambda _s: dummy_session))
monkeypatch.setattr(cli.InMemorySessionService, "append_event", lambda *_a, **_k: None)
# interactive inputs: immediately 'exit'
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "exit")
# collect echo output
captured: list[str] = []
monkeypatch.setattr(click, "echo", lambda m: captured.append(m))
saved_path = tmp_path / "prev.session.json"
saved_path.write_text("{}") # contents not used patched above
await cli.run_cli(
agent_parent_dir=str(parent_dir),
agent_folder_name=folder,
input_file=None,
saved_session_file=str(saved_path),
save_session=False,
)
# ④ ensure both historical messages were printed
assert any("[user]: hi" in m for m in captured)
assert any("[assistant]: hello!" in m for m in captured)

View File

@ -31,7 +31,7 @@ async def test_no_examples():
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[]) agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
invocation_context = utils.create_invocation_context( invocation_context = await utils.create_invocation_context(
agent=agent, user_content="" agent=agent, user_content=""
) )
@ -69,7 +69,7 @@ async def test_agent_examples():
name="agent", name="agent",
examples=example_list, examples=example_list,
) )
invocation_context = utils.create_invocation_context( invocation_context = await utils.create_invocation_context(
agent=agent, user_content="test" agent=agent, user_content="test"
) )
@ -122,7 +122,7 @@ async def test_agent_base_example_provider():
name="agent", name="agent",
examples=provider, examples=provider,
) )
invocation_context = utils.create_invocation_context( invocation_context = await utils.create_invocation_context(
agent=agent, user_content="test" agent=agent, user_content="test"
) )

View File

@ -81,7 +81,7 @@ async def invoke_tool_with_callbacks(
before_tool_callback=before_cb, before_tool_callback=before_cb,
after_tool_callback=after_cb, after_tool_callback=after_cb,
) )
invocation_context = utils.create_invocation_context( invocation_context = await utils.create_invocation_context(
agent=agent, user_content="" agent=agent, user_content=""
) )
# Build function call event # Build function call event

View File

@ -28,7 +28,7 @@ async def test_no_description():
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
agent = Agent(model="gemini-1.5-flash", name="agent") agent = Agent(model="gemini-1.5-flash", name="agent")
invocation_context = utils.create_invocation_context(agent=agent) invocation_context = await utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async( async for _ in identity.request_processor.run_async(
invocation_context, invocation_context,
@ -52,7 +52,7 @@ async def test_with_description():
name="agent", name="agent",
description="test description", description="test description",
) )
invocation_context = utils.create_invocation_context(agent=agent) invocation_context = await utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async( async for _ in identity.request_processor.run_async(
invocation_context, invocation_context,

View File

@ -36,7 +36,7 @@ async def test_build_system_instruction():
{{customer_int }, { non-identifier-float}}, \ {{customer_int }, { non-identifier-float}}, \
{'key1': 'value1'} and {{'key2': 'value2'}}."""), {'key1': 'value1'} and {{'key2': 'value2'}}."""),
) )
invocation_context = utils.create_invocation_context(agent=agent) invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -73,7 +73,7 @@ async def test_function_system_instruction():
name="agent", name="agent",
instruction=build_function_instruction, instruction=build_function_instruction,
) )
invocation_context = utils.create_invocation_context(agent=agent) invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -111,7 +111,7 @@ async def test_async_function_system_instruction():
name="agent", name="agent",
instruction=build_function_instruction, instruction=build_function_instruction,
) )
invocation_context = utils.create_invocation_context(agent=agent) invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -147,7 +147,7 @@ async def test_global_system_instruction():
model="gemini-1.5-flash", model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
invocation_context = utils.create_invocation_context(agent=sub_agent) invocation_context = await utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -189,7 +189,7 @@ async def test_function_global_system_instruction():
model="gemini-1.5-flash", model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
invocation_context = utils.create_invocation_context(agent=sub_agent) invocation_context = await utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -231,7 +231,7 @@ async def test_async_function_global_system_instruction():
model="gemini-1.5-flash", model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=""),
) )
invocation_context = utils.create_invocation_context(agent=sub_agent) invocation_context = await utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
@ -263,7 +263,7 @@ async def test_build_system_instruction_with_namespace():
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}.""" """Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
), ),
) )
invocation_context = utils.create_invocation_context(agent=agent) invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session( invocation_context.session = Session(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",

View File

@ -37,26 +37,28 @@ def get_session_service(
return InMemorySessionService() return InMemorySessionService()
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_get_empty_session(service_type): async def test_get_empty_session(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
assert not session_service.get_session( assert not await session_service.get_session(
app_name='my_app', user_id='test_user', session_id='123' app_name='my_app', user_id='test_user', session_id='123'
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_create_get_session(service_type): async def test_create_get_session(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'test_user' user_id = 'test_user'
state = {'key': 'value'} state = {'key': 'value'}
session = session_service.create_session( session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state app_name=app_name, user_id=user_id, state=state
) )
assert session.app_name == app_name assert session.app_name == app_name
@ -64,50 +66,53 @@ def test_create_get_session(service_type):
assert session.id assert session.id
assert session.state == state assert session.state == state
assert ( assert (
session_service.get_session( await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
== session == session
) )
session_id = session.id session_id = session.id
session_service.delete_session( await session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
assert ( assert (
not session_service.get_session( await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
== session != session
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_create_and_list_sessions(service_type): async def test_create_and_list_sessions(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'test_user' user_id = 'test_user'
session_ids = ['session' + str(i) for i in range(5)] session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids: for session_id in session_ids:
session_service.create_session( await session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
sessions = session_service.list_sessions( list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id app_name=app_name, user_id=user_id
).sessions )
sessions = list_sessions_response.sessions
for i in range(len(sessions)): for i in range(len(sessions)):
assert sessions[i].id == session_ids[i] assert sessions[i].id == session_ids[i]
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_session_state(service_type): async def test_session_state(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id_1 = 'user1' user_id_1 = 'user1'
@ -118,19 +123,19 @@ def test_session_state(service_type):
state_11 = {'key11': 'value11'} state_11 = {'key11': 'value11'}
state_12 = {'key12': 'value12'} state_12 = {'key12': 'value12'}
session_11 = session_service.create_session( session_11 = await session_service.create_session(
app_name=app_name, app_name=app_name,
user_id=user_id_1, user_id=user_id_1,
state=state_11, state=state_11,
session_id=session_id_11, session_id=session_id_11,
) )
session_service.create_session( await session_service.create_session(
app_name=app_name, app_name=app_name,
user_id=user_id_1, user_id=user_id_1,
state=state_12, state=state_12,
session_id=session_id_12, session_id=session_id_12,
) )
session_service.create_session( await session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2 app_name=app_name, user_id=user_id_2, session_id=session_id_2
) )
@ -149,7 +154,7 @@ def test_session_state(service_type):
} }
), ),
) )
session_service.append_event(session=session_11, event=event) await session_service.append_event(session=session_11, event=event)
# User and app state is stored, temp state is filtered. # User and app state is stored, temp state is filtered.
assert session_11.state.get('app:key') == 'value' assert session_11.state.get('app:key') == 'value'
@ -157,7 +162,7 @@ def test_session_state(service_type):
assert session_11.state.get('user:key1') == 'value1' assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key') assert not session_11.state.get('temp:key')
session_12 = session_service.get_session( session_12 = await session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_12 app_name=app_name, user_id=user_id_1, session_id=session_id_12
) )
# After getting a new instance, the session_12 got the user and app state, # After getting a new instance, the session_12 got the user and app state,
@ -166,7 +171,7 @@ def test_session_state(service_type):
assert not session_12.state.get('temp:key') assert not session_12.state.get('temp:key')
# The user1's state is not visible to user2, app state is visible # The user1's state is not visible to user2, app state is visible
session_2 = session_service.get_session( session_2 = await session_service.get_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2 app_name=app_name, user_id=user_id_2, session_id=session_id_2
) )
assert session_2.state.get('app:key') == 'value' assert session_2.state.get('app:key') == 'value'
@ -175,7 +180,7 @@ def test_session_state(service_type):
assert not session_2.state.get('user:key1') assert not session_2.state.get('user:key1')
# The change to session_11 is persisted # The change to session_11 is persisted
session_11 = session_service.get_session( session_11 = await session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_11 app_name=app_name, user_id=user_id_1, session_id=session_id_11
) )
assert session_11.state.get('key11') == 'value11_new' assert session_11.state.get('key11') == 'value11_new'
@ -183,10 +188,11 @@ def test_session_state(service_type):
assert not session_11.state.get('temp:key') assert not session_11.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_create_new_session_will_merge_states(service_type): async def test_create_new_session_will_merge_states(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
@ -194,7 +200,7 @@ def test_create_new_session_will_merge_states(service_type):
session_id_2 = 'session2' session_id_2 = 'session2'
state_1 = {'key1': 'value1'} state_1 = {'key1': 'value1'}
session_1 = session_service.create_session( session_1 = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1 app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
) )
@ -210,7 +216,7 @@ def test_create_new_session_will_merge_states(service_type):
} }
), ),
) )
session_service.append_event(session=session_1, event=event) await session_service.append_event(session=session_1, event=event)
# User and app state is stored, temp state is filtered. # User and app state is stored, temp state is filtered.
assert session_1.state.get('app:key') == 'value' assert session_1.state.get('app:key') == 'value'
@ -218,7 +224,7 @@ def test_create_new_session_will_merge_states(service_type):
assert session_1.state.get('user:key1') == 'value1' assert session_1.state.get('user:key1') == 'value1'
assert not session_1.state.get('temp:key') assert not session_1.state.get('temp:key')
session_2 = session_service.create_session( session_2 = await session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2 app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
) )
# Session 2 has the persisted states # Session 2 has the persisted states
@ -228,15 +234,18 @@ def test_create_new_session_will_merge_states(service_type):
assert not session_2.state.get('temp:key') assert not session_2.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_append_event_bytes(service_type): async def test_append_event_bytes(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id) session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event( event = Event(
invocation_id='invocation', invocation_id='invocation',
author='user', author='user',
@ -249,30 +258,34 @@ def test_append_event_bytes(service_type):
], ],
), ),
) )
session_service.append_event(session=session, event=event) await session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes( assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png' data=b'test_image_data', mime_type='image/png'
) )
events = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
).events )
events = session.events
assert len(events) == 1 assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes( assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png' data=b'test_image_data', mime_type='image/png'
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
) )
def test_append_event_complete(service_type): async def test_append_event_complete(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id) session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event( event = Event(
invocation_id='invocation', invocation_id='invocation',
author='user', author='user',
@ -291,65 +304,73 @@ def test_append_event_complete(service_type):
error_message='error_message', error_message='error_message',
interrupted=True, interrupted=True,
) )
session_service.append_event(session=session, event=event) await session_service.append_event(session=session, event=event)
assert ( assert (
session_service.get_session( await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
) )
== session == session
) )
@pytest.mark.asyncio
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY]) @pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
def test_get_session_with_config(service_type): async def test_get_session_with_config(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'
user_id = 'user' user_id = 'user'
num_test_events = 5 num_test_events = 5
session = session_service.create_session(app_name=app_name, user_id=user_id) session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(1, num_test_events + 1): for i in range(1, num_test_events + 1):
event = Event(author='user', timestamp=i) event = Event(author='user', timestamp=i)
session_service.append_event(session, event) await session_service.append_event(session, event)
# No config, expect all events to be returned. # No config, expect all events to be returned.
events = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id app_name=app_name, user_id=user_id, session_id=session.id
).events )
events = session.events
assert len(events) == num_test_events assert len(events) == num_test_events
# Only expect the most recent 3 events. # Only expect the most recent 3 events.
num_recent_events = 3 num_recent_events = 3
config = GetSessionConfig(num_recent_events=num_recent_events) config = GetSessionConfig(num_recent_events=num_recent_events)
events = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events )
events = session.events
assert len(events) == num_recent_events assert len(events) == num_recent_events
assert events[0].timestamp == num_test_events - num_recent_events + 1 assert events[0].timestamp == num_test_events - num_recent_events + 1
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events. # Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
after_timestamp = 4.0 after_timestamp = 4.0
config = GetSessionConfig(after_timestamp=after_timestamp) config = GetSessionConfig(after_timestamp=after_timestamp)
events = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events )
events = session.events
assert len(events) == num_test_events - after_timestamp + 1 assert len(events) == num_test_events - after_timestamp + 1
assert events[0].timestamp == after_timestamp assert events[0].timestamp == after_timestamp
# Expect no events if none are > after_timestamp. # Expect no events if none are > after_timestamp.
way_after_timestamp = num_test_events * 10 way_after_timestamp = num_test_events * 10
config = GetSessionConfig(after_timestamp=way_after_timestamp) config = GetSessionConfig(after_timestamp=way_after_timestamp)
events = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events )
assert len(events) == 0 assert not session.events
# Both filters applied, i.e., of 3 most recent events, only 2 are after # Both filters applied, i.e., of 3 most recent events, only 2 are after
# timestamp 4.0, so expect 2 events. # timestamp 4.0, so expect 2 events.
config = GetSessionConfig( config = GetSessionConfig(
after_timestamp=after_timestamp, num_recent_events=num_recent_events after_timestamp=after_timestamp, num_recent_events=num_recent_events
) )
events = session_service.get_session( session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events )
events = session.events
assert len(events) == num_test_events - after_timestamp + 1 assert len(events) == num_test_events - after_timestamp + 1

View File

@ -15,7 +15,7 @@
import re import re
import this import this
from typing import Any from typing import Any
import uuid
from dateutil.parser import isoparse from dateutil.parser import isoparse
from google.adk.events import Event from google.adk.events import Event
from google.adk.events import EventActions from google.adk.events import EventActions
@ -124,7 +124,9 @@ class MockApiClient:
this.session_dict: dict[str, Any] = {} this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {} this.event_dict: dict[str, list[Any]] = {}
def request(self, http_method: str, path: str, request_dict: dict[str, Any]): async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method.""" """Mocks the API Client request method."""
if http_method == 'GET': if http_method == 'GET':
if re.match(SESSION_REGEX, path): if re.match(SESSION_REGEX, path):
@ -210,46 +212,52 @@ def mock_vertex_ai_session_service():
return service return service
def test_get_empty_session(): @pytest.mark.asyncio
async def test_get_empty_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
assert session_service.get_session( assert await session_service.get_session(
app_name='123', user_id='user', session_id='0' app_name='123', user_id='user', session_id='0'
) )
assert str(excinfo.value) == 'Session not found: 0' assert str(excinfo.value) == 'Session not found: 0'
def test_get_and_delete_session(): @pytest.mark.asyncio
async def test_get_and_delete_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
assert ( assert (
session_service.get_session( await session_service.get_session(
app_name='123', user_id='user', session_id='1' app_name='123', user_id='user', session_id='1'
) )
== MOCK_SESSION == MOCK_SESSION
) )
session_service.delete_session(app_name='123', user_id='user', session_id='1') await session_service.delete_session(
app_name='123', user_id='user', session_id='1'
)
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
assert session_service.get_session( assert await session_service.get_session(
app_name='123', user_id='user', session_id='1' app_name='123', user_id='user', session_id='1'
) )
assert str(excinfo.value) == 'Session not found: 1' assert str(excinfo.value) == 'Session not found: 1'
def test_list_sessions(): @pytest.mark.asyncio
async def test_list_sessions():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
sessions = session_service.list_sessions(app_name='123', user_id='user') sessions = await session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2 assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1' assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2' assert sessions.sessions[1].id == '2'
def test_create_session(): @pytest.mark.asyncio
async def test_create_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
state = {'key': 'value'} state = {'key': 'value'}
session = session_service.create_session( session = await session_service.create_session(
app_name='123', user_id='user', state=state app_name='123', user_id='user', state=state
) )
assert session.state == state assert session.state == state
@ -258,16 +266,17 @@ def test_create_session():
assert session.last_update_time is not None assert session.last_update_time is not None
session_id = session.id session_id = session.id
assert session == session_service.get_session( assert session == await session_service.get_session(
app_name='123', user_id='user', session_id=session_id app_name='123', user_id='user', session_id=session_id
) )
def test_create_session_with_custom_session_id(): @pytest.mark.asyncio
async def test_create_session_with_custom_session_id():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
session_service.create_session( await session_service.create_session(
app_name='123', user_id='user', session_id='1' app_name='123', user_id='user', session_id='1'
) )
assert str(excinfo.value) == ( assert str(excinfo.value) == (

View File

@ -37,9 +37,9 @@ class _TestingTool(BaseTool):
return self.declaration return self.declaration
def _create_tool_context() -> ToolContext: async def _create_tool_context() -> ToolContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
) )
agent = SequentialAgent(name='test_agent') agent = SequentialAgent(name='test_agent')
@ -55,7 +55,7 @@ def _create_tool_context() -> ToolContext:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_llm_request_no_declaration(): async def test_process_llm_request_no_declaration():
tool = _TestingTool() tool = _TestingTool()
tool_context = _create_tool_context() tool_context = await _create_tool_context()
llm_request = LlmRequest() llm_request = LlmRequest()
await tool.process_llm_request( await tool.process_llm_request(
@ -77,7 +77,7 @@ async def test_process_llm_request_with_declaration():
) )
tool = _TestingTool(declaration) tool = _TestingTool(declaration)
llm_request = LlmRequest() llm_request = LlmRequest()
tool_context = _create_tool_context() tool_context = await _create_tool_context()
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request
@ -102,7 +102,7 @@ async def test_process_llm_request_with_builtin_tool():
tools=[types.Tool(google_search=types.GoogleSearch())] tools=[types.Tool(google_search=types.GoogleSearch())]
) )
) )
tool_context = _create_tool_context() tool_context = await _create_tool_context()
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request
@ -131,7 +131,7 @@ async def test_process_llm_request_with_builtin_tool_and_another_declaration():
] ]
) )
) )
tool_context = _create_tool_context() tool_context = await _create_tool_context()
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request

View File

@ -56,7 +56,7 @@ class ModelContent(types.Content):
super().__init__(role='model', parts=parts) super().__init__(role='model', parts=parts)
def create_invocation_context(agent: Agent, user_content: str = ''): async def create_invocation_context(agent: Agent, user_content: str = ''):
invocation_id = 'test_id' invocation_id = 'test_id'
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService() session_service = InMemorySessionService()
@ -67,7 +67,7 @@ def create_invocation_context(agent: Agent, user_content: str = ''):
memory_service=memory_service, memory_service=memory_service,
invocation_id=invocation_id, invocation_id=invocation_id,
agent=agent, agent=agent,
session=session_service.create_session( session=await session_service.create_session(
app_name='test_app', user_id='test_user' app_name='test_app', user_id='test_user'
), ),
user_content=types.Content( user_content=types.Content(
@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner):
self, new_message: types.ContentUnion self, new_message: types.ContentUnion
) -> list[Event]: ) -> list[Event]:
session = self.session_service.create_session( session = await self.session_service.create_session(
app_name='InMemoryRunner', user_id='test_user' app_name='InMemoryRunner', user_id='test_user'
) )
collected_events = [] collected_events = []
@ -172,14 +172,22 @@ class InMemoryRunner:
session_service=InMemorySessionService(), session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(), memory_service=InMemoryMemoryService(),
) )
self.session_id = self.runner.session_service.create_session( self.session_id = None
app_name='test_app', user_id='test_user'
).id
@property @property
def session(self) -> Session: def session(self) -> Session:
return self.runner.session_service.get_session( if not self.session_id:
app_name='test_app', user_id='test_user', session_id=self.session_id session = asyncio.run(
self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
)
)
self.session_id = session.id
return session
return asyncio.run(
self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
)
) )
def run(self, new_message: types.ContentUnion) -> list[Event]: def run(self, new_message: types.ContentUnion) -> list[Event]:
@ -194,9 +202,9 @@ class InMemoryRunner:
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
collected_responses = [] collected_responses = []
async def consume_responses(): async def consume_responses(session: Session):
run_res = self.runner.run_live( run_res = self.runner.run_live(
session=self.session, session=session,
live_request_queue=live_request_queue, live_request_queue=live_request_queue,
) )
@ -207,7 +215,8 @@ class InMemoryRunner:
return return
try: try:
asyncio.run(consume_responses()) session = self.session
asyncio.run(consume_responses(session))
except asyncio.TimeoutError: except asyncio.TimeoutError:
print('Returning any partial results collected so far.') print('Returning any partial results collected so far.')