mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
ADK changes
PiperOrigin-RevId: 759259620
This commit is contained in:
committed by
Copybara-Service
parent
1804ca39a6
commit
05917cabbd
@@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent, branch: Optional[str] = None
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -134,7 +134,7 @@ def test_invalid_agent_name():
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_branch(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='parent_branch'
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_before_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_async_before_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_before_agent_callback_bypass_agent,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_async_before_agent_callback_bypass_agent,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
@@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
result = [e async for e in agent.run_async(parent_ctx)]
|
||||
@@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
result = [e async for e in agent.run_async(parent_ctx)]
|
||||
@@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_after_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
|
||||
@@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_async_after_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
|
||||
@@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_after_agent_callback_append_agent_reply,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_async_after_agent_callback_append_agent_reply,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
@@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_with_branch(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='parent_branch'
|
||||
)
|
||||
|
||||
@@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
|
||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""Unit tests for canonical_xxx fields in LlmAgent."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
@@ -30,11 +30,11 @@ from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
def _create_readonly_context(
|
||||
async def _create_readonly_context(
|
||||
agent: LlmAgent, state: Optional[dict[str, Any]] = None
|
||||
) -> ReadonlyContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user', state=state
|
||||
)
|
||||
invocation_context = InvocationContext(
|
||||
@@ -77,7 +77,7 @@ def test_canonical_model_inherit():
|
||||
|
||||
async def test_canonical_instruction_str():
|
||||
agent = LlmAgent(name='test_agent', instruction='instruction')
|
||||
ctx = _create_readonly_context(agent)
|
||||
ctx = await _create_readonly_context(agent)
|
||||
|
||||
canonical_instruction = await agent.canonical_instruction(ctx)
|
||||
assert canonical_instruction == 'instruction'
|
||||
@@ -88,7 +88,9 @@ async def test_canonical_instruction():
|
||||
return f'instruction: {ctx.state["state_var"]}'
|
||||
|
||||
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_instruction = await agent.canonical_instruction(ctx)
|
||||
assert canonical_instruction == 'instruction: state_value'
|
||||
@@ -99,7 +101,9 @@ async def test_async_canonical_instruction():
|
||||
return f'instruction: {ctx.state["state_var"]}'
|
||||
|
||||
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_instruction = await agent.canonical_instruction(ctx)
|
||||
assert canonical_instruction == 'instruction: state_value'
|
||||
@@ -107,10 +111,10 @@ async def test_async_canonical_instruction():
|
||||
|
||||
async def test_canonical_global_instruction_str():
|
||||
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
|
||||
ctx = _create_readonly_context(agent)
|
||||
ctx = await _create_readonly_context(agent)
|
||||
|
||||
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_global_instruction == 'global instruction'
|
||||
canonical_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_instruction == 'global instruction'
|
||||
|
||||
|
||||
async def test_canonical_global_instruction():
|
||||
@@ -120,7 +124,9 @@ async def test_canonical_global_instruction():
|
||||
agent = LlmAgent(
|
||||
name='test_agent', global_instruction=_global_instruction_provider
|
||||
)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_global_instruction == 'global instruction: state_value'
|
||||
@@ -133,10 +139,14 @@ async def test_async_canonical_global_instruction():
|
||||
agent = LlmAgent(
|
||||
name='test_agent', global_instruction=_global_instruction_provider
|
||||
)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
ctx = await _create_readonly_context(
|
||||
agent, state={'state_var': 'state_value'}
|
||||
)
|
||||
|
||||
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
|
||||
assert canonical_global_instruction == 'global instruction: state_value'
|
||||
assert (
|
||||
await agent.canonical_global_instruction(ctx)
|
||||
== 'global instruction: state_value'
|
||||
)
|
||||
|
||||
|
||||
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
|
||||
|
||||
@@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, loop_agent
|
||||
)
|
||||
events = [e async for e in loop_agent.run_async(parent_ctx)]
|
||||
@@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
|
||||
name=f'{request.function.__name__}_test_loop_agent',
|
||||
sub_agents=[non_escalating_agent, escalating_agent],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, loop_agent
|
||||
)
|
||||
events = [e async for e in loop_agent.run_async(parent_ctx)]
|
||||
|
||||
@@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, parallel_agent
|
||||
)
|
||||
events = [e async for e in parallel_agent.run_async(parent_ctx)]
|
||||
|
||||
@@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
@@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_async(parent_ctx)]
|
||||
@@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_live(parent_ctx)]
|
||||
|
||||
Reference in New Issue
Block a user