mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
feat! Update session service interface to be async.
Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option. PiperOrigin-RevId: 759224250
This commit is contained in:
committed by
Copybara-Service
parent
d161a2c3f7
commit
5b3204c356
@@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent, branch: Optional[str] = None
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -134,7 +134,7 @@ def test_invalid_agent_name():
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_branch(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='parent_branch'
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_before_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_async_before_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_before_agent_callback_bypass_agent,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_async_before_agent_callback_bypass_agent,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
result = [e async for e in agent.run_async(parent_ctx)]
|
||||
@@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
result = [e async for e in agent.run_async(parent_ctx)]
|
||||
@@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_after_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
|
||||
@@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_async_after_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
|
||||
@@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_after_agent_callback_append_agent_reply,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_async_after_agent_callback_append_agent_reply,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_with_branch(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='parent_branch'
|
||||
)
|
||||
|
||||
@@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
|
||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""Unit tests for canonical_xxx fields in LlmAgent."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
@@ -30,11 +30,11 @@ from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
def _create_readonly_context(
|
||||
async def _create_readonly_context(
|
||||
agent: LlmAgent, state: Optional[dict[str, Any]] = None
|
||||
) -> ReadonlyContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user', state=state
|
||||
)
|
||||
invocation_context = InvocationContext(
|
||||
@@ -77,7 +77,7 @@ def test_canonical_model_inherit():
|
||||
|
||||
async def test_canonical_instruction_str():
|
||||
agent = LlmAgent(name='test_agent', instruction='instruction')
|
||||
ctx = _create_readonly_context(agent)
|
||||
ctx = await _create_readonly_context(agent)
|
||||
|
||||
canonical_instruction = await agent.canonical_instruction(ctx)
|
||||
assert canonical_instruction == 'instruction'
|
||||
@@ -88,7 +88,9 @@ async def test_canonical_instruction():
|
||||
return f'instruction: {ctx.state["state_var"]}'
|
||||
|
||||
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_instruction = await agent.canonical_instruction(ctx)
|
||||
assert canonical_instruction == 'instruction: state_value'
|
||||
@@ -99,7 +101,9 @@ async def test_async_canonical_instruction():
|
||||
return f'instruction: {ctx.state["state_var"]}'
|
||||
|
||||
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_instruction = await agent.canonical_instruction(ctx)
|
||||
assert canonical_instruction == 'instruction: state_value'
|
||||
@@ -107,10 +111,10 @@ async def test_async_canonical_instruction():
|
||||
|
||||
async def test_canonical_global_instruction_str():
|
||||
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
|
||||
ctx = _create_readonly_context(agent)
|
||||
ctx = await _create_readonly_context(agent)
|
||||
|
||||
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_global_instruction == 'global instruction'
|
||||
canonical_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_instruction == 'global instruction'
|
||||
|
||||
|
||||
async def test_canonical_global_instruction():
|
||||
@@ -120,7 +124,9 @@ async def test_canonical_global_instruction():
|
||||
agent = LlmAgent(
|
||||
name='test_agent', global_instruction=_global_instruction_provider
|
||||
)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_global_instruction == 'global instruction: state_value'
|
||||
@@ -133,10 +139,14 @@ async def test_async_canonical_global_instruction():
|
||||
agent = LlmAgent(
|
||||
name='test_agent', global_instruction=_global_instruction_provider
|
||||
)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_global_instruction == 'global instruction: state_value'
|
||||
assert (
|
||||
await agent.canonical_global_instruction(ctx)
|
||||
== 'global instruction: state_value'
|
||||
)
|
||||
|
||||
|
||||
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
|
||||
|
||||
@@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, loop_agent
|
||||
)
|
||||
events = [e async for e in loop_agent.run_async(parent_ctx)]
|
||||
@@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
|
||||
name=f'{request.function.__name__}_test_loop_agent',
|
||||
sub_agents=[non_escalating_agent, escalating_agent],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, loop_agent
|
||||
)
|
||||
events = [e async for e in loop_agent.run_async(parent_ctx)]
|
||||
|
||||
@@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, parallel_agent
|
||||
)
|
||||
events = [e async for e in parallel_agent.run_async(parent_ctx)]
|
||||
|
||||
@@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_async(parent_ctx)]
|
||||
@@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_live(parent_ctx)]
|
||||
|
||||
@@ -192,65 +192,22 @@ async def test_run_cli_save_session(fake_agent, tmp_path: Path, monkeypatch: pyt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""run_interactively should skip blank input, echo once, then exit."""
|
||||
# make a session that belongs to dummy agent
|
||||
svc = cli.InMemorySessionService()
|
||||
sess = svc.create_session(app_name="dummy", user_id="u")
|
||||
artifact_service = cli.InMemoryArtifactService()
|
||||
root_agent = types.SimpleNamespace(name="root")
|
||||
"""run_interactively should skip blank input, echo once, then exit."""
|
||||
# make a session that belongs to dummy agent
|
||||
svc = cli.InMemorySessionService()
|
||||
sess = await svc.create_session(app_name="dummy", user_id="u")
|
||||
artifact_service = cli.InMemoryArtifactService()
|
||||
root_agent = types.SimpleNamespace(name="root")
|
||||
|
||||
# fake user input: blank -> 'hello' -> 'exit'
|
||||
answers = iter([" ", "hello", "exit"])
|
||||
monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers))
|
||||
# fake user input: blank -> 'hello' -> 'exit'
|
||||
answers = iter([" ", "hello", "exit"])
|
||||
monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(answers))
|
||||
|
||||
# capture assisted echo
|
||||
echoed: list[str] = []
|
||||
monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg))
|
||||
# capture assisted echo
|
||||
echoed: list[str] = []
|
||||
monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg))
|
||||
|
||||
await cli.run_interactively(root_agent, artifact_service, sess, svc)
|
||||
await cli.run_interactively(root_agent, artifact_service, sess, svc)
|
||||
|
||||
# verify: assistant echoed once with 'echo:hello'
|
||||
assert any("echo:hello" in m for m in echoed)
|
||||
|
||||
|
||||
# run_cli (resume branch)
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cli_resume_saved_session(tmp_path: Path, fake_agent, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""run_cli should load previous session, print its events, then re-enter interactive mode."""
|
||||
parent_dir, folder = fake_agent
|
||||
|
||||
# stub Session.model_validate_json to return dummy session with two events
|
||||
user_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hi")])
|
||||
assistant_content = types.SimpleNamespace(parts=[types.SimpleNamespace(text="hello!")])
|
||||
dummy_session = types.SimpleNamespace(
|
||||
id="sess",
|
||||
app_name=folder,
|
||||
user_id="u",
|
||||
events=[
|
||||
types.SimpleNamespace(author="user", content=user_content, partial=False),
|
||||
types.SimpleNamespace(author="assistant", content=assistant_content, partial=False),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(cli.Session, "model_validate_json", staticmethod(lambda _s: dummy_session))
|
||||
monkeypatch.setattr(cli.InMemorySessionService, "append_event", lambda *_a, **_k: None)
|
||||
# interactive inputs: immediately 'exit'
|
||||
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "exit")
|
||||
|
||||
# collect echo output
|
||||
captured: list[str] = []
|
||||
monkeypatch.setattr(click, "echo", lambda m: captured.append(m))
|
||||
|
||||
saved_path = tmp_path / "prev.session.json"
|
||||
saved_path.write_text("{}") # contents not used – patched above
|
||||
|
||||
await cli.run_cli(
|
||||
agent_parent_dir=str(parent_dir),
|
||||
agent_folder_name=folder,
|
||||
input_file=None,
|
||||
saved_session_file=str(saved_path),
|
||||
save_session=False,
|
||||
)
|
||||
|
||||
# ④ ensure both historical messages were printed
|
||||
assert any("[user]: hi" in m for m in captured)
|
||||
assert any("[assistant]: hello!" in m for m in captured)
|
||||
# verify: assistant echoed once with 'echo:hello'
|
||||
assert any("echo:hello" in m for m in echoed)
|
||||
|
||||
@@ -31,7 +31,7 @@ async def test_no_examples():
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
|
||||
invocation_context = utils.create_invocation_context(
|
||||
invocation_context = await utils.create_invocation_context(
|
||||
agent=agent, user_content=""
|
||||
)
|
||||
|
||||
@@ -69,7 +69,7 @@ async def test_agent_examples():
|
||||
name="agent",
|
||||
examples=example_list,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
invocation_context = await utils.create_invocation_context(
|
||||
agent=agent, user_content="test"
|
||||
)
|
||||
|
||||
@@ -122,7 +122,7 @@ async def test_agent_base_example_provider():
|
||||
name="agent",
|
||||
examples=provider,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
invocation_context = await utils.create_invocation_context(
|
||||
agent=agent, user_content="test"
|
||||
)
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ async def invoke_tool_with_callbacks(
|
||||
before_tool_callback=before_cb,
|
||||
after_tool_callback=after_cb,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
invocation_context = await utils.create_invocation_context(
|
||||
agent=agent, user_content=""
|
||||
)
|
||||
# Build function call event
|
||||
|
||||
@@ -28,7 +28,7 @@ async def test_no_description():
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(model="gemini-1.5-flash", name="agent")
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
|
||||
async for _ in identity.request_processor.run_async(
|
||||
invocation_context,
|
||||
@@ -52,7 +52,7 @@ async def test_with_description():
|
||||
name="agent",
|
||||
description="test description",
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
|
||||
async for _ in identity.request_processor.run_async(
|
||||
invocation_context,
|
||||
|
||||
@@ -36,7 +36,7 @@ async def test_build_system_instruction():
|
||||
{{customer_int }, { non-identifier-float}}, \
|
||||
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
@@ -73,7 +73,7 @@ async def test_function_system_instruction():
|
||||
name="agent",
|
||||
instruction=build_function_instruction,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
@@ -111,7 +111,7 @@ async def test_async_function_system_instruction():
|
||||
name="agent",
|
||||
instruction=build_function_instruction,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
@@ -147,7 +147,7 @@ async def test_global_system_instruction():
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
@@ -189,7 +189,7 @@ async def test_function_global_system_instruction():
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
@@ -231,7 +231,7 @@ async def test_async_function_global_system_instruction():
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
@@ -263,7 +263,7 @@ async def test_build_system_instruction_with_namespace():
|
||||
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
|
||||
),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
|
||||
@@ -37,26 +37,28 @@ def get_session_service(
|
||||
return InMemorySessionService()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_get_empty_session(service_type):
|
||||
async def test_get_empty_session(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
assert not session_service.get_session(
|
||||
assert not await session_service.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='123'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_get_session(service_type):
|
||||
async def test_create_get_session(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
state = {'key': 'value'}
|
||||
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
assert session.app_name == app_name
|
||||
@@ -64,50 +66,53 @@ def test_create_get_session(service_type):
|
||||
assert session.id
|
||||
assert session.state == state
|
||||
assert (
|
||||
session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
session_id = session.id
|
||||
session_service.delete_session(
|
||||
await session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
assert (
|
||||
not session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
!= session
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_and_list_sessions(service_type):
|
||||
async def test_create_and_list_sessions(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
|
||||
session_ids = ['session' + str(i) for i in range(5)]
|
||||
for session_id in session_ids:
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
sessions = session_service.list_sessions(
|
||||
list_sessions_response = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
)
|
||||
sessions = list_sessions_response.sessions
|
||||
for i in range(len(sessions)):
|
||||
assert sessions[i].id == session_ids[i]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_session_state(service_type):
|
||||
async def test_session_state(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
@@ -118,19 +123,19 @@ def test_session_state(service_type):
|
||||
state_11 = {'key11': 'value11'}
|
||||
state_12 = {'key12': 'value12'}
|
||||
|
||||
session_11 = session_service.create_session(
|
||||
session_11 = await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_11,
|
||||
session_id=session_id_11,
|
||||
)
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_12,
|
||||
session_id=session_id_12,
|
||||
)
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
|
||||
@@ -149,7 +154,7 @@ def test_session_state(service_type):
|
||||
}
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session_11, event=event)
|
||||
await session_service.append_event(session=session_11, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_11.state.get('app:key') == 'value'
|
||||
@@ -157,7 +162,7 @@ def test_session_state(service_type):
|
||||
assert session_11.state.get('user:key1') == 'value1'
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
session_12 = session_service.get_session(
|
||||
session_12 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_12
|
||||
)
|
||||
# After getting a new instance, the session_12 got the user and app state,
|
||||
@@ -166,7 +171,7 @@ def test_session_state(service_type):
|
||||
assert not session_12.state.get('temp:key')
|
||||
|
||||
# The user1's state is not visible to user2, app state is visible
|
||||
session_2 = session_service.get_session(
|
||||
session_2 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
assert session_2.state.get('app:key') == 'value'
|
||||
@@ -175,7 +180,7 @@ def test_session_state(service_type):
|
||||
assert not session_2.state.get('user:key1')
|
||||
|
||||
# The change to session_11 is persisted
|
||||
session_11 = session_service.get_session(
|
||||
session_11 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_11
|
||||
)
|
||||
assert session_11.state.get('key11') == 'value11_new'
|
||||
@@ -183,10 +188,11 @@ def test_session_state(service_type):
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_new_session_will_merge_states(service_type):
|
||||
async def test_create_new_session_will_merge_states(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
@@ -194,7 +200,7 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
session_id_2 = 'session2'
|
||||
state_1 = {'key1': 'value1'}
|
||||
|
||||
session_1 = session_service.create_session(
|
||||
session_1 = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
|
||||
)
|
||||
|
||||
@@ -210,7 +216,7 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
}
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session_1, event=event)
|
||||
await session_service.append_event(session=session_1, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_1.state.get('app:key') == 'value'
|
||||
@@ -218,7 +224,7 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
assert session_1.state.get('user:key1') == 'value1'
|
||||
assert not session_1.state.get('temp:key')
|
||||
|
||||
session_2 = session_service.create_session(
|
||||
session_2 = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
|
||||
)
|
||||
# Session 2 has the persisted states
|
||||
@@ -228,15 +234,18 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
assert not session_2.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_bytes(service_type):
|
||||
async def test_append_event_bytes(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
@@ -249,30 +258,34 @@ def test_append_event_bytes(service_type):
|
||||
],
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
assert session.events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == 1
|
||||
assert events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_complete(service_type):
|
||||
async def test_append_event_complete(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
@@ -291,65 +304,73 @@ def test_append_event_complete(service_type):
|
||||
error_message='error_message',
|
||||
interrupted=True,
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
|
||||
def test_get_session_with_config(service_type):
|
||||
async def test_get_session_with_config(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
num_test_events = 5
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
for i in range(1, num_test_events + 1):
|
||||
event = Event(author='user', timestamp=i)
|
||||
session_service.append_event(session, event)
|
||||
await session_service.append_event(session, event)
|
||||
|
||||
# No config, expect all events to be returned.
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_test_events
|
||||
|
||||
# Only expect the most recent 3 events.
|
||||
num_recent_events = 3
|
||||
config = GetSessionConfig(num_recent_events=num_recent_events)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_recent_events
|
||||
assert events[0].timestamp == num_test_events - num_recent_events + 1
|
||||
|
||||
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
|
||||
after_timestamp = 4.0
|
||||
config = GetSessionConfig(after_timestamp=after_timestamp)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_test_events - after_timestamp + 1
|
||||
assert events[0].timestamp == after_timestamp
|
||||
|
||||
# Expect no events if none are > after_timestamp.
|
||||
way_after_timestamp = num_test_events * 10
|
||||
config = GetSessionConfig(after_timestamp=way_after_timestamp)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
assert len(events) == 0
|
||||
)
|
||||
assert not session.events
|
||||
|
||||
# Both filters applied, i.e., of 3 most recent events, only 2 are after
|
||||
# timestamp 4.0, so expect 2 events.
|
||||
config = GetSessionConfig(
|
||||
after_timestamp=after_timestamp, num_recent_events=num_recent_events
|
||||
)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_test_events - after_timestamp + 1
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import re
|
||||
import this
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from google.adk.events import Event
|
||||
from google.adk.events import EventActions
|
||||
@@ -124,7 +124,9 @@ class MockApiClient:
|
||||
this.session_dict: dict[str, Any] = {}
|
||||
this.event_dict: dict[str, list[Any]] = {}
|
||||
|
||||
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
|
||||
async def async_request(
|
||||
self, http_method: str, path: str, request_dict: dict[str, Any]
|
||||
):
|
||||
"""Mocks the API Client request method."""
|
||||
if http_method == 'GET':
|
||||
if re.match(SESSION_REGEX, path):
|
||||
@@ -210,46 +212,52 @@ def mock_vertex_ai_session_service():
|
||||
return service
|
||||
|
||||
|
||||
def test_get_empty_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_empty_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert session_service.get_session(
|
||||
assert await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='0'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 0'
|
||||
|
||||
|
||||
def test_get_and_delete_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_and_delete_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
== MOCK_SESSION
|
||||
)
|
||||
|
||||
session_service.delete_session(app_name='123', user_id='user', session_id='1')
|
||||
await session_service.delete_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert session_service.get_session(
|
||||
assert await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
|
||||
def test_list_sessions():
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
sessions = await session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
|
||||
def test_create_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
state = {'key': 'value'}
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='123', user_id='user', state=state
|
||||
)
|
||||
assert session.state == state
|
||||
@@ -258,16 +266,17 @@ def test_create_session():
|
||||
assert session.last_update_time is not None
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
assert session == await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_with_custom_session_id():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_with_custom_session_id():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == (
|
||||
|
||||
@@ -37,9 +37,9 @@ class _TestingTool(BaseTool):
|
||||
return self.declaration
|
||||
|
||||
|
||||
def _create_tool_context() -> ToolContext:
|
||||
async def _create_tool_context() -> ToolContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
agent = SequentialAgent(name='test_agent')
|
||||
@@ -55,7 +55,7 @@ def _create_tool_context() -> ToolContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_no_declaration():
|
||||
tool = _TestingTool()
|
||||
tool_context = _create_tool_context()
|
||||
tool_context = await _create_tool_context()
|
||||
llm_request = LlmRequest()
|
||||
|
||||
await tool.process_llm_request(
|
||||
@@ -77,7 +77,7 @@ async def test_process_llm_request_with_declaration():
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest()
|
||||
tool_context = _create_tool_context()
|
||||
tool_context = await _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
@@ -102,7 +102,7 @@ async def test_process_llm_request_with_builtin_tool():
|
||||
tools=[types.Tool(google_search=types.GoogleSearch())]
|
||||
)
|
||||
)
|
||||
tool_context = _create_tool_context()
|
||||
tool_context = await _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
@@ -131,7 +131,7 @@ async def test_process_llm_request_with_builtin_tool_and_another_declaration():
|
||||
]
|
||||
)
|
||||
)
|
||||
tool_context = _create_tool_context()
|
||||
tool_context = await _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
|
||||
@@ -56,7 +56,7 @@ class ModelContent(types.Content):
|
||||
super().__init__(role='model', parts=parts)
|
||||
|
||||
|
||||
def create_invocation_context(agent: Agent, user_content: str = ''):
|
||||
async def create_invocation_context(agent: Agent, user_content: str = ''):
|
||||
invocation_id = 'test_id'
|
||||
artifact_service = InMemoryArtifactService()
|
||||
session_service = InMemorySessionService()
|
||||
@@ -67,7 +67,7 @@ def create_invocation_context(agent: Agent, user_content: str = ''):
|
||||
memory_service=memory_service,
|
||||
invocation_id=invocation_id,
|
||||
agent=agent,
|
||||
session=session_service.create_session(
|
||||
session=await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
),
|
||||
user_content=types.Content(
|
||||
@@ -141,7 +141,7 @@ class TestInMemoryRunner(AfInMemoryRunner):
|
||||
self, new_message: types.ContentUnion
|
||||
) -> list[Event]:
|
||||
|
||||
session = self.session_service.create_session(
|
||||
session = await self.session_service.create_session(
|
||||
app_name='InMemoryRunner', user_id='test_user'
|
||||
)
|
||||
collected_events = []
|
||||
@@ -172,14 +172,22 @@ class InMemoryRunner:
|
||||
session_service=InMemorySessionService(),
|
||||
memory_service=InMemoryMemoryService(),
|
||||
)
|
||||
self.session_id = self.runner.session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
).id
|
||||
self.session_id = None
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
return self.runner.session_service.get_session(
|
||||
app_name='test_app', user_id='test_user', session_id=self.session_id
|
||||
if not self.session_id:
|
||||
session = asyncio.run(
|
||||
self.runner.session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
)
|
||||
self.session_id = session.id
|
||||
return session
|
||||
return asyncio.run(
|
||||
self.runner.session_service.get_session(
|
||||
app_name='test_app', user_id='test_user', session_id=self.session_id
|
||||
)
|
||||
)
|
||||
|
||||
def run(self, new_message: types.ContentUnion) -> list[Event]:
|
||||
@@ -194,9 +202,9 @@ class InMemoryRunner:
|
||||
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
|
||||
collected_responses = []
|
||||
|
||||
async def consume_responses():
|
||||
async def consume_responses(session: Session):
|
||||
run_res = self.runner.run_live(
|
||||
session=self.session,
|
||||
session=session,
|
||||
live_request_queue=live_request_queue,
|
||||
)
|
||||
|
||||
@@ -207,7 +215,8 @@ class InMemoryRunner:
|
||||
return
|
||||
|
||||
try:
|
||||
asyncio.run(consume_responses())
|
||||
session = self.session
|
||||
asyncio.run(consume_responses(session))
|
||||
except asyncio.TimeoutError:
|
||||
print('Returning any partial results collected so far.')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user