ADK changes

PiperOrigin-RevId: 759259620
This commit is contained in:
Google Team Member
2025-05-15 12:46:12 -07:00
committed by Copybara-Service
parent 1804ca39a6
commit 05917cabbd
23 changed files with 264 additions and 268 deletions

View File

@@ -110,11 +110,11 @@ class _TestingAgent(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent, branch: Optional[str] = None
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@@ -134,7 +134,7 @@ def test_invalid_agent_name():
@pytest.mark.asyncio
async def test_run_async(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@@ -148,7 +148,7 @@ async def test_run_async(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_async_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch'
)
@@ -170,7 +170,7 @@ async def test_run_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@@ -198,7 +198,7 @@ async def test_run_async_with_async_before_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@@ -226,7 +226,7 @@ async def test_run_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@@ -253,7 +253,7 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
@@ -394,7 +394,7 @@ async def test_before_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
@@ -455,7 +455,7 @@ async def test_after_agent_callbacks_chain(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
@@ -494,7 +494,7 @@ async def test_run_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@@ -520,7 +520,7 @@ async def test_run_async_with_async_after_agent_callback_noop(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
@@ -545,7 +545,7 @@ async def test_run_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@@ -570,7 +570,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@@ -589,7 +589,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
@pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@@ -600,7 +600,7 @@ async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_live(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)
@@ -614,7 +614,7 @@ async def test_run_live(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_live_with_branch(request: pytest.FixtureRequest):
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent, branch='parent_branch'
)
@@ -629,7 +629,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
@pytest.mark.asyncio
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)

View File

@@ -15,7 +15,7 @@
"""Unit tests for canonical_xxx fields in LlmAgent."""
from typing import Any
from typing import Optional
from typing import Optional, cast
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
@@ -30,11 +30,11 @@ from pydantic import BaseModel
import pytest
def _create_readonly_context(
async def _create_readonly_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None
) -> ReadonlyContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user', state=state
)
invocation_context = InvocationContext(
@@ -77,7 +77,7 @@ def test_canonical_model_inherit():
async def test_canonical_instruction_str():
agent = LlmAgent(name='test_agent', instruction='instruction')
ctx = _create_readonly_context(agent)
ctx = await _create_readonly_context(agent)
canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction'
@@ -88,7 +88,9 @@ async def test_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value'
@@ -99,7 +101,9 @@ async def test_async_canonical_instruction():
return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_instruction = await agent.canonical_instruction(ctx)
assert canonical_instruction == 'instruction: state_value'
@@ -107,10 +111,10 @@ async def test_async_canonical_instruction():
async def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = _create_readonly_context(agent)
ctx = await _create_readonly_context(agent)
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction'
canonical_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_instruction == 'global instruction'
async def test_canonical_global_instruction():
@@ -120,7 +124,9 @@ async def test_canonical_global_instruction():
agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider
)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction: state_value'
@@ -133,10 +139,14 @@ async def test_async_canonical_global_instruction():
agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider
)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
ctx = await _create_readonly_context(
agent, state={'state_var': 'state_value'}
)
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
assert canonical_global_instruction == 'global instruction: state_value'
assert (
await agent.canonical_global_instruction(ctx)
== 'global instruction: state_value'
)
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):

View File

@@ -70,11 +70,11 @@ class _TestingAgentWithEscalateAction(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@@ -95,7 +95,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, loop_agent
)
events = [e async for e in loop_agent.run_async(parent_ctx)]
@@ -119,7 +119,7 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
name=f'{request.function.__name__}_test_loop_agent',
sub_agents=[non_escalating_agent, escalating_agent],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, loop_agent
)
events = [e async for e in loop_agent.run_async(parent_ctx)]

View File

@@ -47,11 +47,11 @@ class _TestingAgent(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@@ -76,7 +76,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent2,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent
)
events = [e async for e in parallel_agent.run_async(parent_ctx)]

View File

@@ -53,11 +53,11 @@ class _TestingAgent(BaseAgent):
)
def _create_parent_invocation_context(
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = session_service.create_session(
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
@@ -79,7 +79,7 @@ async def test_run_async(request: pytest.FixtureRequest):
agent_2,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_async(parent_ctx)]
@@ -102,7 +102,7 @@ async def test_run_live(request: pytest.FixtureRequest):
agent_2,
],
)
parent_ctx = _create_parent_invocation_context(
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, sequential_agent
)
events = [e async for e in sequential_agent.run_live(parent_ctx)]