From 5b3204c356c4a13a661038cc2a77ebf336d0a112 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Thu, 15 May 2025 11:16:43 -0700 Subject: [PATCH] feat! Update session service interface to be async. Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option. PiperOrigin-RevId: 759224250 --- src/google/adk/cli/cli.py | 10 +- src/google/adk/cli/fast_api.py | 41 ++++--- .../adk/evaluation/evaluation_generator.py | 2 +- src/google/adk/runners.py | 10 +- .../adk/sessions/base_session_service.py | 17 +-- .../adk/sessions/database_session_service.py | 13 +- .../adk/sessions/in_memory_session_service.py | 21 ++-- .../adk/sessions/vertex_ai_session_service.py | 26 ++-- src/google/adk/tools/agent_tool.py | 2 +- tests/unittests/agents/test_base_agent.py | 36 +++--- .../unittests/agents/test_llm_agent_fields.py | 36 ++++-- tests/unittests/agents/test_loop_agent.py | 8 +- tests/unittests/agents/test_parallel_agent.py | 6 +- .../unittests/agents/test_sequential_agent.py | 8 +- tests/unittests/cli/utils/test_cli.py | 73 +++-------- .../flows/llm_flows/_test_examples.py | 6 +- .../llm_flows/test_async_tool_callbacks.py | 2 +- .../flows/llm_flows/test_identity.py | 4 +- .../flows/llm_flows/test_instructions.py | 14 +-- .../sessions/test_session_service.py | 115 +++++++++++------- .../test_vertex_ai_session_service.py | 39 +++--- tests/unittests/tools/test_base_tool.py | 12 +- tests/unittests/utils.py | 31 +++-- 23 files changed, 264 insertions(+), 268 deletions(-) diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 8f466ad..6d267de 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -55,7 +55,7 @@ async def run_input_file( input_file = InputFile.model_validate_json(f.read()) input_file.state['_time'] = datetime.now() - session = session_service.create_session( + session = await session_service.create_session( app_name=app_name, user_id=user_id, state=input_file.state ) for query in input_file.queries: @@ -130,7 +130,7 @@ async def run_cli( agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module = importlib.import_module(agent_folder_name) user_id = 'test_user' - session = session_service.create_session( + session = await session_service.create_session( app_name=agent_folder_name, user_id=user_id ) root_agent = agent_module.agent.root_agent @@ -145,14 +145,12 @@ async def run_cli( input_path=input_file, ) elif saved_session_file: - - loaded_session = None with open(saved_session_file, 'r') as f: loaded_session = Session.model_validate_json(f.read()) if loaded_session: for event in loaded_session.events: - session_service.append_event(session, event) + await session_service.append_event(session, event) content = event.content if not content or not content.parts or not content.parts[0].text: continue @@ -181,7 +179,7 @@ async def run_cli( session_path = f'{agent_module_path}/{session_id}.session.json' # Fetch the session again to get all the details. - session = session_service.get_session( + session = await session_service.get_session( app_name=session.app_name, user_id=session.user_id, session_id=session.id, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index ea49143..058a1ac 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -333,10 +333,12 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions/{session_id}", response_model_exclude_none=True, ) - def get_session(app_name: str, user_id: str, session_id: str) -> Session: + async def get_session( + app_name: str, user_id: str, session_id: str + ) -> Session: # Connect to managed session if agent_engine_id is set. app_name = agent_engine_id if agent_engine_id else app_name - session = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) if not session: @@ -347,14 +349,15 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions", response_model_exclude_none=True, ) - def list_sessions(app_name: str, user_id: str) -> list[Session]: + async def list_sessions(app_name: str, user_id: str) -> list[Session]: # Connect to managed session if agent_engine_id is set. app_name = agent_engine_id if agent_engine_id else app_name + list_sessions_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id + ) return [ session - for session in session_service.list_sessions( - app_name=app_name, user_id=user_id - ).sessions + for session in list_sessions_response.sessions # Remove sessions that were generated as a part of Eval. if not session.id.startswith(EVAL_SESSION_ID_PREFIX) ] @@ -363,7 +366,7 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions/{session_id}", response_model_exclude_none=True, ) - def create_session_with_id( + async def create_session_with_id( app_name: str, user_id: str, session_id: str, @@ -372,7 +375,7 @@ def get_fast_api_app( # Connect to managed session if agent_engine_id is set. app_name = agent_engine_id if agent_engine_id else app_name if ( - session_service.get_session( + await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) is not None @@ -382,7 +385,7 @@ def get_fast_api_app( status_code=400, detail=f"Session already exists: {session_id}" ) logger.info("New session created: %s", session_id) - return session_service.create_session( + return await session_service.create_session( app_name=app_name, user_id=user_id, state=state, session_id=session_id ) @@ -390,7 +393,7 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions", response_model_exclude_none=True, ) - def create_session( + async def create_session( app_name: str, user_id: str, state: Optional[dict[str, Any]] = None, @@ -398,7 +401,7 @@ def get_fast_api_app( # Connect to managed session if agent_engine_id is set. app_name = agent_engine_id if agent_engine_id else app_name logger.info("New session created") - return session_service.create_session( + return await session_service.create_session( app_name=app_name, user_id=user_id, state=state ) @@ -442,7 +445,7 @@ def get_fast_api_app( app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest ): # Get the session - session = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=req.user_id, session_id=req.session_id ) assert session, "Session not found." @@ -530,7 +533,7 @@ def get_fast_api_app( session_id=eval_case_result.session_id, ) ) - eval_case_result.session_details = session_service.get_session( + eval_case_result.session_details = await session_service.get_session( app_name=app_name, user_id=eval_case_result.user_id, session_id=eval_case_result.session_id, @@ -615,10 +618,10 @@ def get_fast_api_app( return eval_result_files @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") - def delete_session(app_name: str, user_id: str, session_id: str): + async def delete_session(app_name: str, user_id: str, session_id: str): # Connect to managed session if agent_engine_id is set. app_name = agent_engine_id if agent_engine_id else app_name - session_service.delete_session( + await session_service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -713,7 +716,7 @@ def get_fast_api_app( async def agent_run(req: AgentRunRequest) -> list[Event]: # Connect to managed session if agent_engine_id is set. app_id = agent_engine_id if agent_engine_id else req.app_name - session = session_service.get_session( + session = await session_service.get_session( app_name=app_id, user_id=req.user_id, session_id=req.session_id ) if not session: @@ -735,7 +738,7 @@ def get_fast_api_app( # Connect to managed session if agent_engine_id is set. app_id = agent_engine_id if agent_engine_id else req.app_name # SSE endpoint - session = session_service.get_session( + session = await session_service.get_session( app_name=app_id, user_id=req.user_id, session_id=req.session_id ) if not session: @@ -776,7 +779,7 @@ def get_fast_api_app( ): # Connect to managed session if agent_engine_id is set. app_id = agent_engine_id if agent_engine_id else app_name - session = session_service.get_session( + session = await session_service.get_session( app_name=app_id, user_id=user_id, session_id=session_id ) session_events = session.events if session else [] @@ -833,7 +836,7 @@ def get_fast_api_app( # Connect to managed session if agent_engine_id is set. app_id = agent_engine_id if agent_engine_id else app_name - session = session_service.get_session( + session = await session_service.get_session( app_name=app_id, user_id=user_id, session_id=session_id ) if not session: diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index c59868e..f07b3f8 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -126,7 +126,7 @@ class EvaluationGenerator: user_id = initial_session.user_id if initial_session else "test_user_id" session_id = session_id if session_id else str(uuid.uuid4()) - _ = session_service.create_session( + _ = await session_service.create_session( app_name=app_name, user_id=user_id, state=initial_session.state if initial_session else {}, diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index d8cdd63..e56b79b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -173,7 +173,7 @@ class Runner: The events generated by the agent. """ with tracer.start_as_current_span('invocation'): - session = self.session_service.get_session( + session = await self.session_service.get_session( app_name=self.app_name, user_id=user_id, session_id=session_id ) if not session: @@ -197,7 +197,7 @@ class Runner: invocation_context.agent = self._find_agent_to_run(session, root_agent) async for event in invocation_context.agent.run_async(invocation_context): if not event.partial: - self.session_service.append_event(session=session, event=event) + await self.session_service.append_event(session=session, event=event) yield event async def _append_new_message_to_session( @@ -242,7 +242,7 @@ class Runner: author='user', content=new_message, ) - self.session_service.append_event(session=session, event=event) + await self.session_service.append_event(session=session, event=event) async def run_live( self, @@ -324,7 +324,7 @@ class Runner: ) async for event in invocation_context.agent.run_live(invocation_context): - self.session_service.append_event(session=session, event=event) + await self.session_service.append_event(session=session, event=event) yield event async def close_session(self, session: Session): @@ -335,7 +335,7 @@ class Runner: """ if self.memory_service: await self.memory_service.add_session_to_memory(session) - self.session_service.close_session(session=session) + await self.session_service.close_session(session=session) def _find_agent_to_run( self, session: Session, root_agent: BaseAgent diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 82dcd99..6a98531 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -47,7 +47,7 @@ class BaseSessionService(abc.ABC): """ @abc.abstractmethod - def create_session( + async def create_session( self, *, app_name: str, @@ -67,10 +67,9 @@ class BaseSessionService(abc.ABC): Returns: session: The newly created session instance. """ - pass @abc.abstractmethod - def get_session( + async def get_session( self, *, app_name: str, @@ -79,28 +78,24 @@ class BaseSessionService(abc.ABC): config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: """Gets a session.""" - pass @abc.abstractmethod - def list_sessions( + async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: """Lists all the sessions.""" - pass @abc.abstractmethod - def delete_session( + async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: """Deletes a session.""" - pass - def close_session(self, *, session: Session): + async def close_session(self, *, session: Session): """Closes a session.""" # TODO: determine whether we want to finalize the session here. - pass - def append_event(self, session: Session, event: Event) -> Event: + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: return event diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index f9d73ae..b1d7028 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -283,7 +283,7 @@ class DatabaseSessionService(BaseSessionService): Base.metadata.create_all(self.db_engine) @override - def create_session( + async def create_session( self, *, app_name: str, @@ -357,7 +357,7 @@ class DatabaseSessionService(BaseSessionService): return session @override - def get_session( + async def get_session( self, *, app_name: str, @@ -431,7 +431,7 @@ class DatabaseSessionService(BaseSessionService): return session @override - def list_sessions( + async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: with self.DatabaseSessionFactory() as sessionFactory: @@ -454,7 +454,7 @@ class DatabaseSessionService(BaseSessionService): return ListSessionsResponse(sessions=sessions) @override - def delete_session( + async def delete_session( self, app_name: str, user_id: str, session_id: str ) -> None: with self.DatabaseSessionFactory() as sessionFactory: @@ -467,7 +467,7 @@ class DatabaseSessionService(BaseSessionService): sessionFactory.commit() @override - def append_event(self, session: Session, event: Event) -> Event: + async def append_event(self, session: Session, event: Event) -> Event: logger.info(f"Append event: {event} to session {session.id}") if event.partial: @@ -552,9 +552,10 @@ class DatabaseSessionService(BaseSessionService): session.last_update_time = storage_session.update_time.timestamp() # Also update the in-memory session - super().append_event(session=session, event=event) + await super().append_event(session=session, event=event) return event + def convert_event(event: StorageEvent) -> Event: """Converts a storage event to an event.""" return Event( diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 0d79420..6f2b4cd 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -44,7 +44,7 @@ class InMemorySessionService(BaseSessionService): self.app_state: dict[str, dict[str, Any]] = {} @override - def create_session( + async def create_session( self, *, app_name: str, @@ -106,7 +106,7 @@ class InMemorySessionService(BaseSessionService): return self._merge_state(app_name, user_id, copied_session) @override - def get_session( + async def get_session( self, *, app_name: str, @@ -193,7 +193,7 @@ class InMemorySessionService(BaseSessionService): return copied_session @override - def list_sessions( + async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: return self._list_sessions_impl(app_name=app_name, user_id=user_id) @@ -221,7 +221,7 @@ class InMemorySessionService(BaseSessionService): sessions_without_events.append(copied_session) return ListSessionsResponse(sessions=sessions_without_events) - def delete_session( + async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: self._delete_session_impl( @@ -250,16 +250,9 @@ class InMemorySessionService(BaseSessionService): self.sessions[app_name][user_id].pop(session_id) @override - def append_event(self, session: Session, event: Event) -> Event: - return self._append_event_impl(session=session, event=event) - - def append_event_sync(self, session: Session, event: Event) -> Event: - logger.warning('Deprecated. Please migrate to the async method.') - return self._append_event_impl(session=session, event=event) - - def _append_event_impl(self, session: Session, event: Event) -> Event: + async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. - super().append_event(session=session, event=event) + await super().append_event(session=session, event=event) session.last_update_time = event.timestamp # Update the storage session @@ -286,7 +279,7 @@ class InMemorySessionService(BaseSessionService): ] = event.actions.state_delta[key] storage_session = self.sessions[app_name][user_id].get(session_id) - super().append_event(session=storage_session, event=event) + await super().append_event(session=storage_session, event=event) storage_session.last_update_time = event.timestamp diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index c49c43a..cb9831e 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -48,7 +48,7 @@ class VertexAiSessionService(BaseSessionService): self.api_client = client._api_client @override - def create_session( + async def create_session( self, *, app_name: str, @@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService): if state: session_json_dict['session_state'] = state - api_response = self.api_client.request( + api_response = await self.api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions', request_dict=session_json_dict, @@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService): max_retry_attempt = 5 while max_retry_attempt >= 0: - lro_response = self.api_client.request( + lro_response = await self.api_client.async_request( http_method='GET', path=f'operations/{operation_id}', request_dict={}, @@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService): max_retry_attempt -= 1 # Get session resource - get_session_api_response = self.api_client.request( + get_session_api_response = await self.api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -112,7 +112,7 @@ class VertexAiSessionService(BaseSessionService): return session @override - def get_session( + async def get_session( self, *, app_name: str, @@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService): reasoning_engine_id = _parse_reasoning_engine_id(app_name) # Get session resource - get_session_api_response = self.api_client.request( + get_session_api_response = await self.api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService): last_update_time=update_timestamp, ) - list_events_api_response = self.api_client.request( + list_events_api_response = await self.api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', request_dict={}, @@ -175,7 +175,7 @@ class VertexAiSessionService(BaseSessionService): return session @override - def list_sessions( + async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: reasoning_engine_id = _parse_reasoning_engine_id(app_name) @@ -202,23 +202,23 @@ class VertexAiSessionService(BaseSessionService): sessions.append(session) return ListSessionsResponse(sessions=sessions) - def delete_session( + async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: reasoning_engine_id = _parse_reasoning_engine_id(app_name) - self.api_client.request( + await self.api_client.async_request( http_method='DELETE', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) @override - def append_event(self, session: Session, event: Event) -> Event: + async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. - super().append_event(session=session, event=event) + await super().append_event(session=session, event=event) reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) - self.api_client.request( + await self.api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', request_dict=_convert_event_to_json(event), diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 7f62829..84e6b09 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -129,7 +129,7 @@ class AgentTool(BaseTool): session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), ) - session = runner.session_service.create_session( + session = await runner.session_service.create_session( app_name=self.agent.name, user_id='tmp_user', state=tool_context.state.to_dict(), diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index e162440..3378143 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent): ) -def _create_parent_invocation_context( +async def _create_parent_invocation_context( test_name: str, agent: BaseAgent, branch: Optional[str] = None ) -> InvocationContext: session_service = InMemorySessionService() - session = session_service.create_session( + session = await session_service.create_session( app_name='test_app', user_id='test_user' ) return InvocationContext( @@ -134,7 +134,7 @@ def test_invalid_agent_name(): @pytest.mark.asyncio async def test_run_async(request: pytest.FixtureRequest): agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) @@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest): @pytest.mark.asyncio async def test_run_async_with_branch(request: pytest.FixtureRequest): agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent, branch='parent_branch' ) @@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop( name=f'{request.function.__name__}_test_agent', before_agent_callback=_before_agent_callback_noop, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) @@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop( name=f'{request.function.__name__}_test_agent', before_agent_callback=_async_before_agent_callback_noop, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) @@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent( name=f'{request.function.__name__}_test_agent', before_agent_callback=_before_agent_callback_bypass_agent, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) @@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent( name=f'{request.function.__name__}_test_agent', before_agent_callback=_async_before_agent_callback_bypass_agent, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) @@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain( name=f'{request.function.__name__}_test_agent', before_agent_callback=[mock_cb for mock_cb in mock_cbs], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] @@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain( name=f'{request.function.__name__}_test_agent', after_agent_callback=[mock_cb for mock_cb in mock_cbs], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] @@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop( name=f'{request.function.__name__}_test_agent', after_agent_callback=_after_agent_callback_noop, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') @@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop( name=f'{request.function.__name__}_test_agent', after_agent_callback=_async_after_agent_callback_noop, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') @@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply( name=f'{request.function.__name__}_test_agent', after_agent_callback=_after_agent_callback_append_agent_reply, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) @@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply( name=f'{request.function.__name__}_test_agent', after_agent_callback=_async_after_agent_callback_append_agent_reply, ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) @@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply( @pytest.mark.asyncio async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) @@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): @pytest.mark.asyncio async def test_run_live(request: pytest.FixtureRequest): agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) @@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest): @pytest.mark.asyncio async def test_run_live_with_branch(request: pytest.FixtureRequest): agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent, branch='parent_branch' ) @@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest): @pytest.mark.asyncio async def test_run_live_incomplete_agent(request: pytest.FixtureRequest): agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, agent ) diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 914fb79..106e20d 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -15,7 +15,7 @@ """Unit tests for canonical_xxx fields in LlmAgent.""" from typing import Any -from typing import Optional +from typing import Optional, cast from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext @@ -30,11 +30,11 @@ from pydantic import BaseModel import pytest -def _create_readonly_context( +async def _create_readonly_context( agent: LlmAgent, state: Optional[dict[str, Any]] = None ) -> ReadonlyContext: session_service = InMemorySessionService() - session = session_service.create_session( + session = await session_service.create_session( app_name='test_app', user_id='test_user', state=state ) invocation_context = InvocationContext( @@ -77,7 +77,7 @@ def test_canonical_model_inherit(): async def test_canonical_instruction_str(): agent = LlmAgent(name='test_agent', instruction='instruction') - ctx = _create_readonly_context(agent) + ctx = await _create_readonly_context(agent) canonical_instruction = await agent.canonical_instruction(ctx) assert canonical_instruction == 'instruction' @@ -88,7 +88,9 @@ async def test_canonical_instruction(): return f'instruction: {ctx.state["state_var"]}' agent = LlmAgent(name='test_agent', instruction=_instruction_provider) - ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) + ctx = await _create_readonly_context( + agent, state={'state_var': 'state_value'} + ) canonical_instruction = await agent.canonical_instruction(ctx) assert canonical_instruction == 'instruction: state_value' @@ -99,7 +101,9 @@ async def test_async_canonical_instruction(): return f'instruction: {ctx.state["state_var"]}' agent = LlmAgent(name='test_agent', instruction=_instruction_provider) - ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) + ctx = await _create_readonly_context( + agent, state={'state_var': 'state_value'} + ) canonical_instruction = await agent.canonical_instruction(ctx) assert canonical_instruction == 'instruction: state_value' @@ -107,10 +111,10 @@ async def test_async_canonical_instruction(): async def test_canonical_global_instruction_str(): agent = LlmAgent(name='test_agent', global_instruction='global instruction') - ctx = _create_readonly_context(agent) + ctx = await _create_readonly_context(agent) - canonical_global_instruction = await agent.canonical_global_instruction(ctx) - assert canonical_global_instruction == 'global instruction' + canonical_instruction = await agent.canonical_global_instruction(ctx) + assert canonical_instruction == 'global instruction' async def test_canonical_global_instruction(): @@ -120,7 +124,9 @@ async def test_canonical_global_instruction(): agent = LlmAgent( name='test_agent', global_instruction=_global_instruction_provider ) - ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) + ctx = await _create_readonly_context( + agent, state={'state_var': 'state_value'} + ) canonical_global_instruction = await agent.canonical_global_instruction(ctx) assert canonical_global_instruction == 'global instruction: state_value' @@ -133,10 +139,14 @@ async def test_async_canonical_global_instruction(): agent = LlmAgent( name='test_agent', global_instruction=_global_instruction_provider ) - ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) + ctx = await _create_readonly_context( + agent, state={'state_var': 'state_value'} + ) - canonical_global_instruction = await agent.canonical_global_instruction(ctx) - assert canonical_global_instruction == 'global instruction: state_value' + assert ( + await agent.canonical_global_instruction(ctx) + == 'global instruction: state_value' + ) def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): diff --git a/tests/unittests/agents/test_loop_agent.py b/tests/unittests/agents/test_loop_agent.py index deafaf2..33ff10f 100644 --- a/tests/unittests/agents/test_loop_agent.py +++ b/tests/unittests/agents/test_loop_agent.py @@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent): ) -def _create_parent_invocation_context( +async def _create_parent_invocation_context( test_name: str, agent: BaseAgent ) -> InvocationContext: session_service = InMemorySessionService() - session = session_service.create_session( + session = await session_service.create_session( app_name='test_app', user_id='test_user' ) return InvocationContext( @@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest): agent, ], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, loop_agent ) events = [e async for e in loop_agent.run_async(parent_ctx)] @@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest): name=f'{request.function.__name__}_test_loop_agent', sub_agents=[non_escalating_agent, escalating_agent], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, loop_agent ) events = [e async for e in loop_agent.run_async(parent_ctx)] diff --git a/tests/unittests/agents/test_parallel_agent.py b/tests/unittests/agents/test_parallel_agent.py index 4d4ff1c..8b29987 100644 --- a/tests/unittests/agents/test_parallel_agent.py +++ b/tests/unittests/agents/test_parallel_agent.py @@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent): ) -def _create_parent_invocation_context( +async def _create_parent_invocation_context( test_name: str, agent: BaseAgent ) -> InvocationContext: session_service = InMemorySessionService() - session = session_service.create_session( + session = await session_service.create_session( app_name='test_app', user_id='test_user' ) return InvocationContext( @@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest): agent2, ], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, parallel_agent ) events = [e async for e in parallel_agent.run_async(parent_ctx)] diff --git a/tests/unittests/agents/test_sequential_agent.py b/tests/unittests/agents/test_sequential_agent.py index f964737..929f714 100644 --- a/tests/unittests/agents/test_sequential_agent.py +++ b/tests/unittests/agents/test_sequential_agent.py @@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent): ) -def _create_parent_invocation_context( +async def _create_parent_invocation_context( test_name: str, agent: BaseAgent ) -> InvocationContext: session_service = InMemorySessionService() - session = session_service.create_session( + session = await session_service.create_session( app_name='test_app', user_id='test_user' ) return InvocationContext( @@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest): agent_2, ], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, sequential_agent ) events = [e async for e in sequential_agent.run_async(parent_ctx)] @@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest): agent_2, ], ) - parent_ctx = _create_parent_invocation_context( + parent_ctx = await _create_parent_invocation_context( request.function.__name__, sequential_agent ) events = [e async for e in sequential_agent.run_live(parent_ctx)] diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 352e470..50307f3 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -192,65 +192,22 @@ async def test_run_cli_save_session(fake_agent, tmp_path: Path, monkeypatch: pyt @pytest.mark.asyncio async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: - """run_interactively should skip blank input, echo once, then exit.""" - # make a session that belongs to dummy agent - svc = cli.InMemorySessionService() - sess = svc.create_session(app_name="dummy", user_id="u") - artifact_service = cli.InMemoryArtifactService() - root_agent = types.SimpleNamespace(name="root") + """run_interactively should skip blank input, echo once, then exit.""" + # make a session that belongs to dummy agent + svc = cli.InMemorySessionService() + sess = await svc.create_session(app_name="dummy", user_id="u") + artifact_service = cli.InMemoryArtifactService() + root_agent = types.SimpleNamespace(name="root") - # fake user input: blank -> 'hello' -> 'exit' - answers = iter([" ", "hello", "exit"]) - monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers)) + # fake user input: blank -> 'hello' -> 'exit' + answers = iter([" ", "hello", "exit"]) + monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers)) - # capture assisted echo - echoed: list[str] = [] - monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) + # capture assisted echo + echoed: list[str] = [] + monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) - await cli.run_interactively(root_agent, artifact_service, sess, svc) + await cli.run_interactively(root_agent, artifact_service, sess, svc) - # verify: assistant echoed once with 'echo:hello' - assert any("echo:hello" in m for m in echoed) - - -# run_cli (resume branch) -@pytest.mark.asyncio -async def test_run_cli_resume_saved_session(tmp_path: Path, fake_agent, monkeypatch: pytest.MonkeyPatch) -> None: - """run_cli should load previous session, print its events, then re-enter interactive mode.""" - parent_dir, folder = fake_agent - - # stub Session.model_validate_json to return dummy session with two events - user_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hi")]) - assistant_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hello!")]) - dummy_session = types.SimpleNamespace( - id="sess", - app_name=folder, - user_id="u", - events=[ - types.SimpleNamespace(author="user", content=user_content, partial=False), - types.SimpleNamespace(author="assistant", content=assistant_content, partial=False), - ], - ) - monkeypatch.setattr(cli.Session, "model_validate_json", staticmethod(lambda _s: dummy_session)) - monkeypatch.setattr(cli.InMemorySessionService, "append_event", lambda *_a, **_k: None) - # interactive inputs: immediately 'exit' - monkeypatch.setattr("builtins.input", lambda *_a, **_k: "exit") - - # collect echo output - captured: list[str] = [] - monkeypatch.setattr(click, "echo", lambda m: captured.append(m)) - - saved_path = tmp_path / "prev.session.json" - saved_path.write_text("{}") # contents not used – patched above - - await cli.run_cli( - agent_parent_dir=str(parent_dir), - agent_folder_name=folder, - input_file=None, - saved_session_file=str(saved_path), - save_session=False, - ) - - # ④ ensure both historical messages were printed - assert any("[user]: hi" in m for m in captured) - assert any("[assistant]: hello!" in m for m in captured) + # verify: assistant echoed once with 'echo:hello' + assert any("echo:hello" in m for m in echoed) diff --git a/tests/unittests/flows/llm_flows/_test_examples.py b/tests/unittests/flows/llm_flows/_test_examples.py index 9b51460..29eb718 100644 --- a/tests/unittests/flows/llm_flows/_test_examples.py +++ b/tests/unittests/flows/llm_flows/_test_examples.py @@ -31,7 +31,7 @@ async def test_no_examples(): config=types.GenerateContentConfig(system_instruction=""), ) agent = Agent(model="gemini-1.5-flash", name="agent", examples=[]) - invocation_context = utils.create_invocation_context( + invocation_context = await utils.create_invocation_context( agent=agent, user_content="" ) @@ -69,7 +69,7 @@ async def test_agent_examples(): name="agent", examples=example_list, ) - invocation_context = utils.create_invocation_context( + invocation_context = await utils.create_invocation_context( agent=agent, user_content="test" ) @@ -122,7 +122,7 @@ async def test_agent_base_example_provider(): name="agent", examples=provider, ) - invocation_context = utils.create_invocation_context( + invocation_context = await utils.create_invocation_context( agent=agent, user_content="test" ) diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py index 8ab66da..051bdde 100644 --- a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -81,7 +81,7 @@ async def invoke_tool_with_callbacks( before_tool_callback=before_cb, after_tool_callback=after_cb, ) - invocation_context = utils.create_invocation_context( + invocation_context = await utils.create_invocation_context( agent=agent, user_content="" ) # Build function call event diff --git a/tests/unittests/flows/llm_flows/test_identity.py b/tests/unittests/flows/llm_flows/test_identity.py index 0e88527..564400c 100644 --- a/tests/unittests/flows/llm_flows/test_identity.py +++ b/tests/unittests/flows/llm_flows/test_identity.py @@ -28,7 +28,7 @@ async def test_no_description(): config=types.GenerateContentConfig(system_instruction=""), ) agent = Agent(model="gemini-1.5-flash", name="agent") - invocation_context = utils.create_invocation_context(agent=agent) + invocation_context = await utils.create_invocation_context(agent=agent) async for _ in identity.request_processor.run_async( invocation_context, @@ -52,7 +52,7 @@ async def test_with_description(): name="agent", description="test description", ) - invocation_context = utils.create_invocation_context(agent=agent) + invocation_context = await utils.create_invocation_context(agent=agent) async for _ in identity.request_processor.run_async( invocation_context, diff --git a/tests/unittests/flows/llm_flows/test_instructions.py b/tests/unittests/flows/llm_flows/test_instructions.py index 0d2ac5e..73117d4 100644 --- a/tests/unittests/flows/llm_flows/test_instructions.py +++ b/tests/unittests/flows/llm_flows/test_instructions.py @@ -36,7 +36,7 @@ async def test_build_system_instruction(): {{customer_int }, { non-identifier-float}}, \ {'key1': 'value1'} and {{'key2': 'value2'}}."""), ) - invocation_context = utils.create_invocation_context(agent=agent) + invocation_context = await utils.create_invocation_context(agent=agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", @@ -73,7 +73,7 @@ async def test_function_system_instruction(): name="agent", instruction=build_function_instruction, ) - invocation_context = utils.create_invocation_context(agent=agent) + invocation_context = await utils.create_invocation_context(agent=agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", @@ -111,7 +111,7 @@ async def test_async_function_system_instruction(): name="agent", instruction=build_function_instruction, ) - invocation_context = utils.create_invocation_context(agent=agent) + invocation_context = await utils.create_invocation_context(agent=agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", @@ -147,7 +147,7 @@ async def test_global_system_instruction(): model="gemini-1.5-flash", config=types.GenerateContentConfig(system_instruction=""), ) - invocation_context = utils.create_invocation_context(agent=sub_agent) + invocation_context = await utils.create_invocation_context(agent=sub_agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", @@ -189,7 +189,7 @@ async def test_function_global_system_instruction(): model="gemini-1.5-flash", config=types.GenerateContentConfig(system_instruction=""), ) - invocation_context = utils.create_invocation_context(agent=sub_agent) + invocation_context = await utils.create_invocation_context(agent=sub_agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", @@ -231,7 +231,7 @@ async def test_async_function_global_system_instruction(): model="gemini-1.5-flash", config=types.GenerateContentConfig(system_instruction=""), ) - invocation_context = utils.create_invocation_context(agent=sub_agent) + invocation_context = await utils.create_invocation_context(agent=sub_agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", @@ -263,7 +263,7 @@ async def test_build_system_instruction_with_namespace(): """Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}.""" ), ) - invocation_context = utils.create_invocation_context(agent=agent) + invocation_context = await utils.create_invocation_context(agent=agent) invocation_context.session = Session( app_name="test_app", user_id="test_user", diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 158bf5e..6bdc8c9 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -37,26 +37,28 @@ def get_session_service( return InMemorySessionService() +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_get_empty_session(service_type): +async def test_get_empty_session(service_type): session_service = get_session_service(service_type) - assert not session_service.get_session( + assert not await session_service.get_session( app_name='my_app', user_id='test_user', session_id='123' ) +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_create_get_session(service_type): +async def test_create_get_session(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'test_user' state = {'key': 'value'} - session = session_service.create_session( + session = await session_service.create_session( app_name=app_name, user_id=user_id, state=state ) assert session.app_name == app_name @@ -64,50 +66,53 @@ def test_create_get_session(service_type): assert session.id assert session.state == state assert ( - session_service.get_session( + await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) == session ) session_id = session.id - session_service.delete_session( + await session_service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) assert ( - not session_service.get_session( + await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) - == session + != session ) +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_create_and_list_sessions(service_type): +async def test_create_and_list_sessions(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'test_user' session_ids = ['session' + str(i) for i in range(5)] for session_id in session_ids: - session_service.create_session( + await session_service.create_session( app_name=app_name, user_id=user_id, session_id=session_id ) - sessions = session_service.list_sessions( + list_sessions_response = await session_service.list_sessions( app_name=app_name, user_id=user_id - ).sessions + ) + sessions = list_sessions_response.sessions for i in range(len(sessions)): assert sessions[i].id == session_ids[i] +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_session_state(service_type): +async def test_session_state(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id_1 = 'user1' @@ -118,19 +123,19 @@ def test_session_state(service_type): state_11 = {'key11': 'value11'} state_12 = {'key12': 'value12'} - session_11 = session_service.create_session( + session_11 = await session_service.create_session( app_name=app_name, user_id=user_id_1, state=state_11, session_id=session_id_11, ) - session_service.create_session( + await session_service.create_session( app_name=app_name, user_id=user_id_1, state=state_12, session_id=session_id_12, ) - session_service.create_session( + await session_service.create_session( app_name=app_name, user_id=user_id_2, session_id=session_id_2 ) @@ -149,7 +154,7 @@ def test_session_state(service_type): } ), ) - session_service.append_event(session=session_11, event=event) + await session_service.append_event(session=session_11, event=event) # User and app state is stored, temp state is filtered. assert session_11.state.get('app:key') == 'value' @@ -157,7 +162,7 @@ def test_session_state(service_type): assert session_11.state.get('user:key1') == 'value1' assert not session_11.state.get('temp:key') - session_12 = session_service.get_session( + session_12 = await session_service.get_session( app_name=app_name, user_id=user_id_1, session_id=session_id_12 ) # After getting a new instance, the session_12 got the user and app state, @@ -166,7 +171,7 @@ def test_session_state(service_type): assert not session_12.state.get('temp:key') # The user1's state is not visible to user2, app state is visible - session_2 = session_service.get_session( + session_2 = await session_service.get_session( app_name=app_name, user_id=user_id_2, session_id=session_id_2 ) assert session_2.state.get('app:key') == 'value' @@ -175,7 +180,7 @@ def test_session_state(service_type): assert not session_2.state.get('user:key1') # The change to session_11 is persisted - session_11 = session_service.get_session( + session_11 = await session_service.get_session( app_name=app_name, user_id=user_id_1, session_id=session_id_11 ) assert session_11.state.get('key11') == 'value11_new' @@ -183,10 +188,11 @@ def test_session_state(service_type): assert not session_11.state.get('temp:key') +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_create_new_session_will_merge_states(service_type): +async def test_create_new_session_will_merge_states(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'user' @@ -194,7 +200,7 @@ def test_create_new_session_will_merge_states(service_type): session_id_2 = 'session2' state_1 = {'key1': 'value1'} - session_1 = session_service.create_session( + session_1 = await session_service.create_session( app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1 ) @@ -210,7 +216,7 @@ def test_create_new_session_will_merge_states(service_type): } ), ) - session_service.append_event(session=session_1, event=event) + await session_service.append_event(session=session_1, event=event) # User and app state is stored, temp state is filtered. assert session_1.state.get('app:key') == 'value' @@ -218,7 +224,7 @@ def test_create_new_session_will_merge_states(service_type): assert session_1.state.get('user:key1') == 'value1' assert not session_1.state.get('temp:key') - session_2 = session_service.create_session( + session_2 = await session_service.create_session( app_name=app_name, user_id=user_id, state={}, session_id=session_id_2 ) # Session 2 has the persisted states @@ -228,15 +234,18 @@ def test_create_new_session_will_merge_states(service_type): assert not session_2.state.get('temp:key') +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_append_event_bytes(service_type): +async def test_append_event_bytes(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'user' - session = session_service.create_session(app_name=app_name, user_id=user_id) + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) event = Event( invocation_id='invocation', author='user', @@ -249,30 +258,34 @@ def test_append_event_bytes(service_type): ], ), ) - session_service.append_event(session=session, event=event) + await session_service.append_event(session=session, event=event) assert session.events[0].content.parts[0] == types.Part.from_bytes( data=b'test_image_data', mime_type='image/png' ) - events = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id - ).events + ) + events = session.events assert len(events) == 1 assert events[0].content.parts[0] == types.Part.from_bytes( data=b'test_image_data', mime_type='image/png' ) +@pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -def test_append_event_complete(service_type): +async def test_append_event_complete(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'user' - session = session_service.create_session(app_name=app_name, user_id=user_id) + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) event = Event( invocation_id='invocation', author='user', @@ -291,65 +304,73 @@ def test_append_event_complete(service_type): error_message='error_message', interrupted=True, ) - session_service.append_event(session=session, event=event) + await session_service.append_event(session=session, event=event) assert ( - session_service.get_session( + await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) == session ) + +@pytest.mark.asyncio @pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY]) -def test_get_session_with_config(service_type): +async def test_get_session_with_config(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'user' num_test_events = 5 - session = session_service.create_session(app_name=app_name, user_id=user_id) + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) for i in range(1, num_test_events + 1): event = Event(author='user', timestamp=i) - session_service.append_event(session, event) + await session_service.append_event(session, event) # No config, expect all events to be returned. - events = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id - ).events + ) + events = session.events assert len(events) == num_test_events # Only expect the most recent 3 events. num_recent_events = 3 config = GetSessionConfig(num_recent_events=num_recent_events) - events = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id, config=config - ).events + ) + events = session.events assert len(events) == num_recent_events assert events[0].timestamp == num_test_events - num_recent_events + 1 # Only expect events after timestamp 4.0 (inclusive), i.e., 2 events. after_timestamp = 4.0 config = GetSessionConfig(after_timestamp=after_timestamp) - events = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id, config=config - ).events + ) + events = session.events assert len(events) == num_test_events - after_timestamp + 1 assert events[0].timestamp == after_timestamp # Expect no events if none are > after_timestamp. way_after_timestamp = num_test_events * 10 config = GetSessionConfig(after_timestamp=way_after_timestamp) - events = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id, config=config - ).events - assert len(events) == 0 + ) + assert not session.events # Both filters applied, i.e., of 3 most recent events, only 2 are after # timestamp 4.0, so expect 2 events. config = GetSessionConfig( after_timestamp=after_timestamp, num_recent_events=num_recent_events ) - events = session_service.get_session( + session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id, config=config - ).events + ) + events = session.events assert len(events) == num_test_events - after_timestamp + 1 diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index d56bdf2..592bce2 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -15,7 +15,7 @@ import re import this from typing import Any -import uuid + from dateutil.parser import isoparse from google.adk.events import Event from google.adk.events import EventActions @@ -124,7 +124,9 @@ class MockApiClient: this.session_dict: dict[str, Any] = {} this.event_dict: dict[str, list[Any]] = {} - def request(self, http_method: str, path: str, request_dict: dict[str, Any]): + async def async_request( + self, http_method: str, path: str, request_dict: dict[str, Any] + ): """Mocks the API Client request method.""" if http_method == 'GET': if re.match(SESSION_REGEX, path): @@ -210,46 +212,52 @@ def mock_vertex_ai_session_service(): return service -def test_get_empty_session(): +@pytest.mark.asyncio +async def test_get_empty_session(): session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: - assert session_service.get_session( + assert await session_service.get_session( app_name='123', user_id='user', session_id='0' ) assert str(excinfo.value) == 'Session not found: 0' -def test_get_and_delete_session(): +@pytest.mark.asyncio +async def test_get_and_delete_session(): session_service = mock_vertex_ai_session_service() assert ( - session_service.get_session( + await session_service.get_session( app_name='123', user_id='user', session_id='1' ) == MOCK_SESSION ) - session_service.delete_session(app_name='123', user_id='user', session_id='1') + await session_service.delete_session( + app_name='123', user_id='user', session_id='1' + ) with pytest.raises(ValueError) as excinfo: - assert session_service.get_session( + assert await session_service.get_session( app_name='123', user_id='user', session_id='1' ) assert str(excinfo.value) == 'Session not found: 1' -def test_list_sessions(): +@pytest.mark.asyncio +async def test_list_sessions(): session_service = mock_vertex_ai_session_service() - sessions = session_service.list_sessions(app_name='123', user_id='user') + sessions = await session_service.list_sessions(app_name='123', user_id='user') assert len(sessions.sessions) == 2 assert sessions.sessions[0].id == '1' assert sessions.sessions[1].id == '2' -def test_create_session(): +@pytest.mark.asyncio +async def test_create_session(): session_service = mock_vertex_ai_session_service() state = {'key': 'value'} - session = session_service.create_session( + session = await session_service.create_session( app_name='123', user_id='user', state=state ) assert session.state == state @@ -258,16 +266,17 @@ def test_create_session(): assert session.last_update_time is not None session_id = session.id - assert session == session_service.get_session( + assert session == await session_service.get_session( app_name='123', user_id='user', session_id=session_id ) -def test_create_session_with_custom_session_id(): +@pytest.mark.asyncio +async def test_create_session_with_custom_session_id(): session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: - session_service.create_session( + await session_service.create_session( app_name='123', user_id='user', session_id='1' ) assert str(excinfo.value) == ( diff --git a/tests/unittests/tools/test_base_tool.py b/tests/unittests/tools/test_base_tool.py index 13f06d7..d450cc0 100644 --- a/tests/unittests/tools/test_base_tool.py +++ b/tests/unittests/tools/test_base_tool.py @@ -37,9 +37,9 @@ class _TestingTool(BaseTool): return self.declaration -def _create_tool_context() -> ToolContext: +async def _create_tool_context() -> ToolContext: session_service = InMemorySessionService() - session = session_service.create_session( + session = await session_service.create_session( app_name='test_app', user_id='test_user' ) agent = SequentialAgent(name='test_agent') @@ -55,7 +55,7 @@ def _create_tool_context() -> ToolContext: @pytest.mark.asyncio async def test_process_llm_request_no_declaration(): tool = _TestingTool() - tool_context = _create_tool_context() + tool_context = await _create_tool_context() llm_request = LlmRequest() await tool.process_llm_request( @@ -77,7 +77,7 @@ async def test_process_llm_request_with_declaration(): ) tool = _TestingTool(declaration) llm_request = LlmRequest() - tool_context = _create_tool_context() + tool_context = await _create_tool_context() await tool.process_llm_request( tool_context=tool_context, llm_request=llm_request @@ -102,7 +102,7 @@ async def test_process_llm_request_with_builtin_tool(): tools=[types.Tool(google_search=types.GoogleSearch())] ) ) - tool_context = _create_tool_context() + tool_context = await _create_tool_context() await tool.process_llm_request( tool_context=tool_context, llm_request=llm_request @@ -131,7 +131,7 @@ async def test_process_llm_request_with_builtin_tool_and_another_declaration(): ] ) ) - tool_context = _create_tool_context() + tool_context = await _create_tool_context() await tool.process_llm_request( tool_context=tool_context, llm_request=llm_request diff --git a/tests/unittests/utils.py b/tests/unittests/utils.py index 2e74db9..139e0d5 100644 --- a/tests/unittests/utils.py +++ b/tests/unittests/utils.py @@ -56,7 +56,7 @@ class ModelContent(types.Content): super().__init__(role='model', parts=parts) -def create_invocation_context(agent: Agent, user_content: str = ''): +async def create_invocation_context(agent: Agent, user_content: str = ''): invocation_id = 'test_id' artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() @@ -67,7 +67,7 @@ def create_invocation_context(agent: Agent, user_content: str = ''): memory_service=memory_service, invocation_id=invocation_id, agent=agent, - session=session_service.create_session( + session=await session_service.create_session( app_name='test_app', user_id='test_user' ), user_content=types.Content( @@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner): self, new_message: types.ContentUnion ) -> list[Event]: - session = self.session_service.create_session( + session = await self.session_service.create_session( app_name='InMemoryRunner', user_id='test_user' ) collected_events = [] @@ -172,14 +172,22 @@ class InMemoryRunner: session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), ) - self.session_id = self.runner.session_service.create_session( - app_name='test_app', user_id='test_user' - ).id + self.session_id = None @property def session(self) -> Session: - return self.runner.session_service.get_session( - app_name='test_app', user_id='test_user', session_id=self.session_id + if not self.session_id: + session = asyncio.run( + self.runner.session_service.create_session( + app_name='test_app', user_id='test_user' + ) + ) + self.session_id = session.id + return session + return asyncio.run( + self.runner.session_service.get_session( + app_name='test_app', user_id='test_user', session_id=self.session_id + ) ) def run(self, new_message: types.ContentUnion) -> list[Event]: @@ -194,9 +202,9 @@ class InMemoryRunner: def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: collected_responses = [] - async def consume_responses(): + async def consume_responses(session: Session): run_res = self.runner.run_live( - session=self.session, + session=session, live_request_queue=live_request_queue, ) @@ -207,7 +215,8 @@ class InMemoryRunner: return try: - asyncio.run(consume_responses()) + session = self.session + asyncio.run(consume_responses(session)) except asyncio.TimeoutError: print('Returning any partial results collected so far.')