fix: ParallelAgent should only append to its immediate sub-agent, not transitive descendants

Restores automatic conversation history sharing for sequential/loop sub-agents.

PiperOrigin-RevId: 766742380
This commit is contained in:
Google Team Member 2025-06-03 11:08:41 -07:00 committed by Copybara-Service
parent 57d99aa789
commit ec8bc7387c
4 changed files with 66 additions and 13 deletions

View File

@ -246,8 +246,6 @@ class BaseAgent(BaseModel):
) -> InvocationContext:
"""Creates a new invocation context for this agent."""
invocation_context = parent_context.model_copy(update={'agent': self})
if parent_context.branch:
invocation_context.branch = f'{parent_context.branch}.{self.name}'
return invocation_context
@property

View File

@ -26,14 +26,20 @@ from ..events.event import Event
from .base_agent import BaseAgent
def _set_branch_for_current_agent(
current_agent: BaseAgent, invocation_context: InvocationContext
):
def _create_branch_ctx_for_sub_agent(
agent: BaseAgent,
sub_agent: BaseAgent,
invocation_context: InvocationContext,
) -> InvocationContext:
"""Create isolated branch for every sub-agent."""
invocation_context = invocation_context.model_copy()
branch_suffix = f"{agent.name}.{sub_agent.name}"
invocation_context.branch = (
f"{invocation_context.branch}.{current_agent.name}"
f"{invocation_context.branch}.{branch_suffix}"
if invocation_context.branch
else current_agent.name
else branch_suffix
)
return invocation_context
async def _merge_agent_run(
@ -90,8 +96,12 @@ class ParallelAgent(BaseAgent):
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
_set_branch_for_current_agent(self, ctx)
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
agent_runs = [
sub_agent.run_async(
_create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
)
for sub_agent in self.sub_agents
]
async for event in _merge_agent_run(agent_runs):
yield event

View File

@ -159,7 +159,7 @@ async def test_run_async_with_branch(request: pytest.FixtureRequest):
assert len(events) == 1
assert events[0].author == agent.name
assert events[0].content.parts[0].text == 'Hello, world!'
assert events[0].branch.endswith(agent.name)
assert events[0].branch == 'parent_branch'
@pytest.mark.asyncio
@ -625,7 +625,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
assert len(events) == 1
assert events[0].author == agent.name
assert events[0].content.parts[0].text == 'Hello, live!'
assert events[0].branch.endswith(agent.name)
assert events[0].branch == 'parent_branch'
@pytest.mark.asyncio

View File

@ -20,6 +20,7 @@ from typing import AsyncGenerator
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.events import Event
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
@ -86,7 +87,51 @@ async def test_run_async(request: pytest.FixtureRequest):
# and agent1 has a delay.
assert events[0].author == agent2.name
assert events[1].author == agent1.name
assert events[0].branch.endswith(agent2.name)
assert events[1].branch.endswith(agent1.name)
assert events[0].branch.endswith(f'{parallel_agent.name}.{agent2.name}')
assert events[1].branch.endswith(f'{parallel_agent.name}.{agent1.name}')
assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!'
assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!'
@pytest.mark.asyncio
async def test_run_async_branches(request: pytest.FixtureRequest):
agent1 = _TestingAgent(
name=f'{request.function.__name__}_test_agent_1',
delay=0.5,
)
agent2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
agent3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3')
sequential_agent = SequentialAgent(
name=f'{request.function.__name__}_test_sequential_agent',
sub_agents=[agent2, agent3],
)
parallel_agent = ParallelAgent(
name=f'{request.function.__name__}_test_parallel_agent',
sub_agents=[
sequential_agent,
agent1,
],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent
)
events = [e async for e in parallel_agent.run_async(parent_ctx)]
assert len(events) == 3
assert (
events[0].author == agent2.name
and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
)
assert (
events[1].author == agent3.name
and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
)
# Descendants of the same sub-agent should have the same branch.
assert events[0].branch == events[1].branch
assert (
events[2].author == agent1.name
and events[2].branch == f'{parallel_agent.name}.{agent1.name}'
)
# Sub-agents should have different branches.
assert events[2].branch != events[1].branch
assert events[2].branch != events[0].branch