feat! Update session service interface to be async.

Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option.

PiperOrigin-RevId: 759224250
This commit is contained in:
Shangjie Chen 2025-05-15 11:16:43 -07:00 committed by Copybara-Service
parent d161a2c3f7
commit 5b3204c356
23 changed files with 264 additions and 268 deletions

View File

@ -55,7 +55,7 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session = session_service.create_session(
session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
@ -130,7 +130,7 @@ async def run_cli(
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
session = await session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
@ -145,14 +145,12 @@ async def run_cli(
input_path=input_file,
)
elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())
if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
await session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
@ -181,7 +179,7 @@ async def run_cli(
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details.
session = session_service.get_session(
session = await session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,

View File

@ -333,10 +333,12 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
async def get_session(
app_name: str, user_id: str, session_id: str
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
@ -347,14 +349,15 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def list_sessions(app_name: str, user_id: str) -> list[Session]:
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
return [
session
for session in session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
for session in list_sessions_response.sessions
# Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
]
@ -363,7 +366,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def create_session_with_id(
async def create_session_with_id(
app_name: str,
user_id: str,
session_id: str,
@ -372,7 +375,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
if (
session_service.get_session(
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
is not None
@ -382,7 +385,7 @@ def get_fast_api_app(
status_code=400, detail=f"Session already exists: {session_id}"
)
logger.info("New session created: %s", session_id)
return session_service.create_session(
return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id
)
@ -390,7 +393,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def create_session(
async def create_session(
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
@ -398,7 +401,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return session_service.create_session(
return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
@ -442,7 +445,7 @@ def get_fast_api_app(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
# Get the session
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
assert session, "Session not found."
@ -530,7 +533,7 @@ def get_fast_api_app(
session_id=eval_case_result.session_id,
)
)
eval_case_result.session_details = session_service.get_session(
eval_case_result.session_details = await session_service.get_session(
app_name=app_name,
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
@ -615,10 +618,10 @@ def get_fast_api_app(
return eval_result_files
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
def delete_session(app_name: str, user_id: str, session_id: str):
async def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session_service.delete_session(
await session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@ -713,7 +716,7 @@ def get_fast_api_app(
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
@ -735,7 +738,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
@ -776,7 +779,7 @@ def get_fast_api_app(
):
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
session_events = session.events if session else []
@ -833,7 +836,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
if not session:

View File

@ -126,7 +126,7 @@ class EvaluationGenerator:
user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4())
_ = session_service.create_session(
_ = await session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.state if initial_session else {},

View File

@ -173,7 +173,7 @@ class Runner:
The events generated by the agent.
"""
with tracer.start_as_current_span('invocation'):
session = self.session_service.get_session(
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
@ -197,7 +197,7 @@ class Runner:
invocation_context.agent = self._find_agent_to_run(session, root_agent)
async for event in invocation_context.agent.run_async(invocation_context):
if not event.partial:
self.session_service.append_event(session=session, event=event)
await self.session_service.append_event(session=session, event=event)
yield event
async def _append_new_message_to_session(
@ -242,7 +242,7 @@ class Runner:
author='user',
content=new_message,
)
self.session_service.append_event(session=session, event=event)
await self.session_service.append_event(session=session, event=event)
async def run_live(
self,
@ -324,7 +324,7 @@ class Runner:
)
async for event in invocation_context.agent.run_live(invocation_context):
self.session_service.append_event(session=session, event=event)
await self.session_service.append_event(session=session, event=event)
yield event
async def close_session(self, session: Session):
@ -335,7 +335,7 @@ class Runner:
"""
if self.memory_service:
await self.memory_service.add_session_to_memory(session)
self.session_service.close_session(session=session)
await self.session_service.close_session(session=session)
def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent

View File

@ -47,7 +47,7 @@ class BaseSessionService(abc.ABC):
"""
@abc.abstractmethod
def create_session(
async def create_session(
self,
*,
app_name: str,
@ -67,10 +67,9 @@ class BaseSessionService(abc.ABC):
Returns:
session: The newly created session instance.
"""
pass
@abc.abstractmethod
def get_session(
async def get_session(
self,
*,
app_name: str,
@ -79,28 +78,24 @@ class BaseSessionService(abc.ABC):
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
"""Gets a session."""
pass
@abc.abstractmethod
def list_sessions(
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
"""Lists all the sessions."""
pass
@abc.abstractmethod
def delete_session(
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
"""Deletes a session."""
pass
def close_session(self, *, session: Session):
async def close_session(self, *, session: Session):
"""Closes a session."""
# TODO: determine whether we want to finalize the session here.
pass
def append_event(self, session: Session, event: Event) -> Event:
async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event

View File

@ -283,7 +283,7 @@ class DatabaseSessionService(BaseSessionService):
Base.metadata.create_all(self.db_engine)
@override
def create_session(
async def create_session(
self,
*,
app_name: str,
@ -357,7 +357,7 @@ class DatabaseSessionService(BaseSessionService):
return session
@override
def get_session(
async def get_session(
self,
*,
app_name: str,
@ -431,7 +431,7 @@ class DatabaseSessionService(BaseSessionService):
return session
@override
def list_sessions(
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
with self.DatabaseSessionFactory() as sessionFactory:
@ -454,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
return ListSessionsResponse(sessions=sessions)
@override
def delete_session(
async def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
with self.DatabaseSessionFactory() as sessionFactory:
@ -467,7 +467,7 @@ class DatabaseSessionService(BaseSessionService):
sessionFactory.commit()
@override
def append_event(self, session: Session, event: Event) -> Event:
async def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}")
if event.partial:
@ -552,9 +552,10 @@ class DatabaseSessionService(BaseSessionService):
session.last_update_time = storage_session.update_time.timestamp()
# Also update the in-memory session
super().append_event(session=session, event=event)
await super().append_event(session=session, event=event)
return event
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
return Event(

View File

@ -44,7 +44,7 @@ class InMemorySessionService(BaseSessionService):
self.app_state: dict[str, dict[str, Any]] = {}
@override
def create_session(
async def create_session(
self,
*,
app_name: str,
@ -106,7 +106,7 @@ class InMemorySessionService(BaseSessionService):
return self._merge_state(app_name, user_id, copied_session)
@override
def get_session(
async def get_session(
self,
*,
app_name: str,
@ -193,7 +193,7 @@ class InMemorySessionService(BaseSessionService):
return copied_session
@override
def list_sessions(
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
@ -221,7 +221,7 @@ class InMemorySessionService(BaseSessionService):
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)
def delete_session(
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
self._delete_session_impl(
@ -250,16 +250,9 @@ class InMemorySessionService(BaseSessionService):
self.sessions[app_name][user_id].pop(session_id)
@override
def append_event(self, session: Session, event: Event) -> Event:
return self._append_event_impl(session=session, event=event)
def append_event_sync(self, session: Session, event: Event) -> Event:
logger.warning('Deprecated. Please migrate to the async method.')
return self._append_event_impl(session=session, event=event)
def _append_event_impl(self, session: Session, event: Event) -> Event:
async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
super().append_event(session=session, event=event)
await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Update the storage session
@ -286,7 +279,7 @@ class InMemorySessionService(BaseSessionService):
] = event.actions.state_delta[key]
storage_session = self.sessions[app_name][user_id].get(session_id)
super().append_event(session=storage_session, event=event)
await super().append_event(session=storage_session, event=event)
storage_session.last_update_time = event.timestamp

View File

@ -48,7 +48,7 @@ class VertexAiSessionService(BaseSessionService):
self.api_client = client._api_client
@override
def create_session(
async def create_session(
self,
*,
app_name: str,
@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService):
if state:
session_json_dict['session_state'] = state
api_response = self.api_client.request(
api_response = await self.api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
request_dict=session_json_dict,
@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt = 5
while max_retry_attempt >= 0:
lro_response = self.api_client.request(
lro_response = await self.api_client.async_request(
http_method='GET',
path=f'operations/{operation_id}',
request_dict={},
@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt -= 1
# Get session resource
get_session_api_response = self.api_client.request(
get_session_api_response = await self.api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
@ -112,7 +112,7 @@ class VertexAiSessionService(BaseSessionService):
return session
@override
def get_session(
async def get_session(
self,
*,
app_name: str,
@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService):
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
# Get session resource
get_session_api_response = self.api_client.request(
get_session_api_response = await self.api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService):
last_update_time=update_timestamp,
)
list_events_api_response = self.api_client.request(
list_events_api_response = await self.api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={},
@ -175,7 +175,7 @@ class VertexAiSessionService(BaseSessionService):
return session
@override
def list_sessions(
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
@ -202,23 +202,23 @@ class VertexAiSessionService(BaseSessionService):
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
def delete_session(
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
self.api_client.request(
await self.api_client.async_request(
http_method='DELETE',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
)
@override
def append_event(self, session: Session, event: Event) -> Event:
async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
super().append_event(session=session, event=event)
await super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
self.api_client.request(
await self.api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
request_dict=_convert_event_to_json(event),

View File

@ -129,7 +129,7 @@ class AgentTool(BaseTool):
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
session = runner.session_service.create_session(
session = await runner.session_service.create_session(
app_name=self.agent.name,
user_id='tmp_user',
state=tool_context.state.to_dict(),

View File

@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent, branch: Optional[str] = None
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@ -134,7 +134,7 @@ def test_invalid_agent_name():
@pytest.mark.asyncio
async def test_run_async(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_async_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch'
)
@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
@pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_live(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_live_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch'
)
@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)

View File

@ -15,7 +15,7 @@
"""Unit tests for canonical_xxx fields in LlmAgent."""
from typing import Any
from typing import Optional
from typing import Optional, cast
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
@ -30,11 +30,11 @@ from pydantic import BaseModel
import pytest
def _create_readonly_context(
async def _create_readonly_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None
) -> ReadonlyContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user', state=state
)
invocation_context = InvocationContext(
@ -77,7 +77,7 @@ def test_canonical_model_inherit():
async def test_canonical_instruction_str():
agent = LlmAgent(name='test_agent', instruction='instruction')
ctx = _create_readonly_context(agent)
ctx = await _create_readonly_context(agent)
canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction'
@ -88,7 +88,9 @@ async def test_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value'
@ -99,7 +101,9 @@ async def test_async_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value'
@ -107,10 +111,10 @@ async def test_async_canonical_instruction():
async def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = _create_readonly_context(agent)
ctx = await _create_readonly_context(agent)
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction'
canonical_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_instruction == 'global instruction'
async def test_canonical_global_instruction():
@ -120,7 +124,9 @@ async def test_canonical_global_instruction():
agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider
)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction: state_value'
@ -133,10 +139,14 @@ async def test_async_canonical_global_instruction():
agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider
)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction: state_value'
assert (
await agent.canonical_global_instruction(ctx)
== 'global instruction: state_value'
)
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):

View File

@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, loop_agent
)
events = [e async for e in loop_agent.run_async(parent_ctx)]
@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
name=f'{request.function.__name__}_test_loop_agent',
sub_agents=[non_escalating_agent, escalating_agent],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, loop_agent
)
events = [e async for e in loop_agent.run_async(parent_ctx)]

View File

@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent2,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent
)
events = [e async for e in parallel_agent.run_async(parent_ctx)]

View File

@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent_2,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_async(parent_ctx)]
@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest):
agent_2,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_live(parent_ctx)]

View File

@ -192,65 +192,22 @@ async def test_run_cli_save_session(fake_agent, tmp_path: Path, monkeypatch: pyt
@pytest.mark.asyncio
async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""run_interactively should skip blank input, echo once, then exit."""
# make a session that belongs to dummy agent
svc = cli.InMemorySessionService()
sess = svc.create_session(app_name="dummy", user_id="u")
artifact_service = cli.InMemoryArtifactService()
root_agent = types.SimpleNamespace(name="root")
"""run_interactively should skip blank input, echo once, then exit."""
# make a session that belongs to dummy agent
svc = cli.InMemorySessionService()
sess = await svc.create_session(app_name="dummy", user_id="u")
artifact_service = cli.InMemoryArtifactService()
root_agent = types.SimpleNamespace(name="root")
# fake user input: blank -> 'hello' -> 'exit'
answers = iter([" ", "hello", "exit"])
monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers))
# fake user input: blank -> 'hello' -> 'exit'
answers = iter([" ", "hello", "exit"])
monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers))
# capture assisted echo
echoed: list[str] = []
monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg))
# capture assisted echo
echoed: list[str] = []
monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg))
await cli.run_interactively(root_agent, artifact_service, sess, svc)
await cli.run_interactively(root_agent, artifact_service, sess, svc)
# verify: assistant echoed once with 'echo:hello'
assert any("echo:hello" in m for m in echoed)
# run_cli (resume branch)
@pytest.mark.asyncio
async def test_run_cli_resume_saved_session(tmp_path: Path, fake_agent, monkeypatch: pytest.MonkeyPatch) -> None:
"""run_cli should load previous session, print its events, then re-enter interactive mode."""
parent_dir, folder = fake_agent
# stub Session.model_validate_json to return dummy session with two events
user_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hi")])
assistant_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hello!")])
dummy_session = types.SimpleNamespace(
id="sess",
app_name=folder,
user_id="u",
events=[
types.SimpleNamespace(author="user", content=user_content, partial=False),
types.SimpleNamespace(author="assistant", content=assistant_content, partial=False),
],
)
monkeypatch.setattr(cli.Session, "model_validate_json", staticmethod(lambda _s: dummy_session))
monkeypatch.setattr(cli.InMemorySessionService, "append_event", lambda *_a, **_k: None)
# interactive inputs: immediately 'exit'
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "exit")
# collect echo output
captured: list[str] = []
monkeypatch.setattr(click, "echo", lambda m: captured.append(m))
saved_path = tmp_path / "prev.session.json"
saved_path.write_text("{}") # contents not used patched above
await cli.run_cli(
agent_parent_dir=str(parent_dir),
agent_folder_name=folder,
input_file=None,
saved_session_file=str(saved_path),
save_session=False,
)
# ④ ensure both historical messages were printed
assert any("[user]: hi" in m for m in captured)
assert any("[assistant]: hello!" in m for m in captured)
# verify: assistant echoed once with 'echo:hello'
assert any("echo:hello" in m for m in echoed)

View File

@ -31,7 +31,7 @@ async def test_no_examples():
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
invocation_context = utils.create_invocation_context(
invocation_context = await utils.create_invocation_context(
agent=agent, user_content=""
)
@ -69,7 +69,7 @@ async def test_agent_examples():
name="agent",
examples=example_list,
)
invocation_context = utils.create_invocation_context(
invocation_context = await utils.create_invocation_context(
agent=agent, user_content="test"
)
@ -122,7 +122,7 @@ async def test_agent_base_example_provider():
name="agent",
examples=provider,
)
invocation_context = utils.create_invocation_context(
invocation_context = await utils.create_invocation_context(
agent=agent, user_content="test"
)

View File

@ -81,7 +81,7 @@ async def invoke_tool_with_callbacks(
before_tool_callback=before_cb,
after_tool_callback=after_cb,
)
invocation_context = utils.create_invocation_context(
invocation_context = await utils.create_invocation_context(
agent=agent, user_content=""
)
# Build function call event

View File

@ -28,7 +28,7 @@ async def test_no_description():
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(model="gemini-1.5-flash", name="agent")
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context = await utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async(
invocation_context,
@ -52,7 +52,7 @@ async def test_with_description():
name="agent",
description="test description",
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context = await utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async(
invocation_context,

View File

@ -36,7 +36,7 @@ async def test_build_system_instruction():
{{customer_int }, { non-identifier-float}}, \
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
@ -73,7 +73,7 @@ async def test_function_system_instruction():
name="agent",
instruction=build_function_instruction,
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
@ -111,7 +111,7 @@ async def test_async_function_system_instruction():
name="agent",
instruction=build_function_instruction,
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
@ -147,7 +147,7 @@ async def test_global_system_instruction():
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context = await utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
@ -189,7 +189,7 @@ async def test_function_global_system_instruction():
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context = await utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
@ -231,7 +231,7 @@ async def test_async_function_global_system_instruction():
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context = await utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
@ -263,7 +263,7 @@ async def test_build_system_instruction_with_namespace():
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
),
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context = await utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",

View File

@ -37,26 +37,28 @@ def get_session_service(
return InMemorySessionService()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_get_empty_session(service_type):
async def test_get_empty_session(service_type):
session_service = get_session_service(service_type)
assert not session_service.get_session(
assert not await session_service.get_session(
app_name='my_app', user_id='test_user', session_id='123'
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_create_get_session(service_type):
async def test_create_get_session(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
state = {'key': 'value'}
session = session_service.create_session(
session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
assert session.app_name == app_name
@ -64,50 +66,53 @@ def test_create_get_session(service_type):
assert session.id
assert session.state == state
assert (
session_service.get_session(
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)
session_id = session.id
session_service.delete_session(
await session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
assert (
not session_service.get_session(
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
!= session
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_create_and_list_sessions(service_type):
async def test_create_and_list_sessions(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids:
session_service.create_session(
await session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
sessions = session_service.list_sessions(
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
)
sessions = list_sessions_response.sessions
for i in range(len(sessions)):
assert sessions[i].id == session_ids[i]
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_session_state(service_type):
async def test_session_state(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id_1 = 'user1'
@ -118,19 +123,19 @@ def test_session_state(service_type):
state_11 = {'key11': 'value11'}
state_12 = {'key12': 'value12'}
session_11 = session_service.create_session(
session_11 = await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_11,
session_id=session_id_11,
)
session_service.create_session(
await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_12,
session_id=session_id_12,
)
session_service.create_session(
await session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
@ -149,7 +154,7 @@ def test_session_state(service_type):
}
),
)
session_service.append_event(session=session_11, event=event)
await session_service.append_event(session=session_11, event=event)
# User and app state is stored, temp state is filtered.
assert session_11.state.get('app:key') == 'value'
@ -157,7 +162,7 @@ def test_session_state(service_type):
assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key')
session_12 = session_service.get_session(
session_12 = await session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_12
)
# After getting a new instance, the session_12 got the user and app state,
@ -166,7 +171,7 @@ def test_session_state(service_type):
assert not session_12.state.get('temp:key')
# The user1's state is not visible to user2, app state is visible
session_2 = session_service.get_session(
session_2 = await session_service.get_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
assert session_2.state.get('app:key') == 'value'
@ -175,7 +180,7 @@ def test_session_state(service_type):
assert not session_2.state.get('user:key1')
# The change to session_11 is persisted
session_11 = session_service.get_session(
session_11 = await session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_11
)
assert session_11.state.get('key11') == 'value11_new'
@ -183,10 +188,11 @@ def test_session_state(service_type):
assert not session_11.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_create_new_session_will_merge_states(service_type):
async def test_create_new_session_will_merge_states(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
@ -194,7 +200,7 @@ def test_create_new_session_will_merge_states(service_type):
session_id_2 = 'session2'
state_1 = {'key1': 'value1'}
session_1 = session_service.create_session(
session_1 = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
)
@ -210,7 +216,7 @@ def test_create_new_session_will_merge_states(service_type):
}
),
)
session_service.append_event(session=session_1, event=event)
await session_service.append_event(session=session_1, event=event)
# User and app state is stored, temp state is filtered.
assert session_1.state.get('app:key') == 'value'
@ -218,7 +224,7 @@ def test_create_new_session_will_merge_states(service_type):
assert session_1.state.get('user:key1') == 'value1'
assert not session_1.state.get('temp:key')
session_2 = session_service.create_session(
session_2 = await session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
)
# Session 2 has the persisted states
@ -228,15 +234,18 @@ def test_create_new_session_will_merge_states(service_type):
assert not session_2.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_bytes(service_type):
async def test_append_event_bytes(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id)
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
invocation_id='invocation',
author='user',
@ -249,30 +258,34 @@ def test_append_event_bytes(service_type):
],
),
)
session_service.append_event(session=session, event=event)
await session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
events = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
).events
)
events = session.events
assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_complete(service_type):
async def test_append_event_complete(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id)
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
invocation_id='invocation',
author='user',
@ -291,65 +304,73 @@ def test_append_event_complete(service_type):
error_message='error_message',
interrupted=True,
)
session_service.append_event(session=session, event=event)
await session_service.append_event(session=session, event=event)
assert (
session_service.get_session(
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)
@pytest.mark.asyncio
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
def test_get_session_with_config(service_type):
async def test_get_session_with_config(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
num_test_events = 5
session = session_service.create_session(app_name=app_name, user_id=user_id)
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(1, num_test_events + 1):
event = Event(author='user', timestamp=i)
session_service.append_event(session, event)
await session_service.append_event(session, event)
# No config, expect all events to be returned.
events = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
).events
)
events = session.events
assert len(events) == num_test_events
# Only expect the most recent 3 events.
num_recent_events = 3
config = GetSessionConfig(num_recent_events=num_recent_events)
events = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events
)
events = session.events
assert len(events) == num_recent_events
assert events[0].timestamp == num_test_events - num_recent_events + 1
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
after_timestamp = 4.0
config = GetSessionConfig(after_timestamp=after_timestamp)
events = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events
)
events = session.events
assert len(events) == num_test_events - after_timestamp + 1
assert events[0].timestamp == after_timestamp
# Expect no events if none are > after_timestamp.
way_after_timestamp = num_test_events * 10
config = GetSessionConfig(after_timestamp=way_after_timestamp)
events = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events
assert len(events) == 0
)
assert not session.events
# Both filters applied, i.e., of 3 most recent events, only 2 are after
# timestamp 4.0, so expect 2 events.
config = GetSessionConfig(
after_timestamp=after_timestamp, num_recent_events=num_recent_events
)
events = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
).events
)
events = session.events
assert len(events) == num_test_events - after_timestamp + 1

View File

@ -15,7 +15,7 @@
import re
import this
from typing import Any
import uuid
from dateutil.parser import isoparse
from google.adk.events import Event
from google.adk.events import EventActions
@ -124,7 +124,9 @@ class MockApiClient:
this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {}
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method."""
if http_method == 'GET':
if re.match(SESSION_REGEX, path):
@ -210,46 +212,52 @@ def mock_vertex_ai_session_service():
return service
def test_get_empty_session():
@pytest.mark.asyncio
async def test_get_empty_session():
session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo:
assert session_service.get_session(
assert await session_service.get_session(
app_name='123', user_id='user', session_id='0'
)
assert str(excinfo.value) == 'Session not found: 0'
def test_get_and_delete_session():
@pytest.mark.asyncio
async def test_get_and_delete_session():
session_service = mock_vertex_ai_session_service()
assert (
session_service.get_session(
await session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
== MOCK_SESSION
)
session_service.delete_session(app_name='123', user_id='user', session_id='1')
await session_service.delete_session(
app_name='123', user_id='user', session_id='1'
)
with pytest.raises(ValueError) as excinfo:
assert session_service.get_session(
assert await session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
assert str(excinfo.value) == 'Session not found: 1'
def test_list_sessions():
@pytest.mark.asyncio
async def test_list_sessions():
session_service = mock_vertex_ai_session_service()
sessions = session_service.list_sessions(app_name='123', user_id='user')
sessions = await session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2'
def test_create_session():
@pytest.mark.asyncio
async def test_create_session():
session_service = mock_vertex_ai_session_service()
state = {'key': 'value'}
session = session_service.create_session(
session = await session_service.create_session(
app_name='123', user_id='user', state=state
)
assert session.state == state
@ -258,16 +266,17 @@ def test_create_session():
assert session.last_update_time is not None
session_id = session.id
assert session == session_service.get_session(
assert session == await session_service.get_session(
app_name='123', user_id='user', session_id=session_id
)
def test_create_session_with_custom_session_id():
@pytest.mark.asyncio
async def test_create_session_with_custom_session_id():
session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo:
session_service.create_session(
await session_service.create_session(
app_name='123', user_id='user', session_id='1'
)
assert str(excinfo.value) == (

View File

@ -37,9 +37,9 @@ class _TestingTool(BaseTool):
return self.declaration
def _create_tool_context() -> ToolContext:
async def _create_tool_context() -> ToolContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
agent = SequentialAgent(name='test_agent')
@ -55,7 +55,7 @@ def _create_tool_context() -> ToolContext:
@pytest.mark.asyncio
async def test_process_llm_request_no_declaration():
tool = _TestingTool()
tool_context = _create_tool_context()
tool_context = await _create_tool_context()
llm_request = LlmRequest()
await tool.process_llm_request(
@ -77,7 +77,7 @@ async def test_process_llm_request_with_declaration():
)
tool = _TestingTool(declaration)
llm_request = LlmRequest()
tool_context = _create_tool_context()
tool_context = await _create_tool_context()
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
@ -102,7 +102,7 @@ async def test_process_llm_request_with_builtin_tool():
tools=[types.Tool(google_search=types.GoogleSearch())]
)
)
tool_context = _create_tool_context()
tool_context = await _create_tool_context()
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
@ -131,7 +131,7 @@ async def test_process_llm_request_with_builtin_tool_and_another_declaration():
]
)
)
tool_context = _create_tool_context()
tool_context = await _create_tool_context()
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request

View File

@ -56,7 +56,7 @@ class ModelContent(types.Content):
super().__init__(role='model', parts=parts)
def create_invocation_context(agent: Agent, user_content: str = ''):
async def create_invocation_context(agent: Agent, user_content: str = ''):
invocation_id = 'test_id'
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
@ -67,7 +67,7 @@ def create_invocation_context(agent: Agent, user_content: str = ''):
memory_service=memory_service,
invocation_id=invocation_id,
agent=agent,
session=session_service.create_session(
session=await session_service.create_session(
app_name='test_app', user_id='test_user'
),
user_content=types.Content(
@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner):
self, new_message: types.ContentUnion
) -> list[Event]:
session = self.session_service.create_session(
session = await self.session_service.create_session(
app_name='InMemoryRunner', user_id='test_user'
)
collected_events = []
@ -172,14 +172,22 @@ class InMemoryRunner:
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
self.session_id = self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
).id
self.session_id = None
@property
def session(self) -> Session:
return self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
if not self.session_id:
session = asyncio.run(
self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
)
)
self.session_id = session.id
return session
return asyncio.run(
self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
)
)
def run(self, new_message: types.ContentUnion) -> list[Event]:
@ -194,9 +202,9 @@ class InMemoryRunner:
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
collected_responses = []
async def consume_responses():
async def consume_responses(session: Session):
run_res = self.runner.run_live(
session=self.session,
session=session,
live_request_queue=live_request_queue,
)
@ -207,7 +215,8 @@ class InMemoryRunner:
return
try:
asyncio.run(consume_responses())
session = self.session
asyncio.run(consume_responses(session))
except asyncio.TimeoutError:
print('Returning any partial results collected so far.')