mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 07:04:51 -06:00
fix: ParallelAgent should only append to its immediate sub-agent, not transitive descendants
Restores automatic conversation history sharing for sequential/loop sub-agents. PiperOrigin-RevId: 766742380
This commit is contained in:
parent
57d99aa789
commit
ec8bc7387c
@ -246,8 +246,6 @@ class BaseAgent(BaseModel):
|
||||
) -> InvocationContext:
|
||||
"""Creates a new invocation context for this agent."""
|
||||
invocation_context = parent_context.model_copy(update={'agent': self})
|
||||
if parent_context.branch:
|
||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||
return invocation_context
|
||||
|
||||
@property
|
||||
|
@ -26,14 +26,20 @@ from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
def _set_branch_for_current_agent(
|
||||
current_agent: BaseAgent, invocation_context: InvocationContext
|
||||
):
|
||||
def _create_branch_ctx_for_sub_agent(
|
||||
agent: BaseAgent,
|
||||
sub_agent: BaseAgent,
|
||||
invocation_context: InvocationContext,
|
||||
) -> InvocationContext:
|
||||
"""Create isolated branch for every sub-agent."""
|
||||
invocation_context = invocation_context.model_copy()
|
||||
branch_suffix = f"{agent.name}.{sub_agent.name}"
|
||||
invocation_context.branch = (
|
||||
f"{invocation_context.branch}.{current_agent.name}"
|
||||
f"{invocation_context.branch}.{branch_suffix}"
|
||||
if invocation_context.branch
|
||||
else current_agent.name
|
||||
else branch_suffix
|
||||
)
|
||||
return invocation_context
|
||||
|
||||
|
||||
async def _merge_agent_run(
|
||||
@ -90,8 +96,12 @@ class ParallelAgent(BaseAgent):
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
_set_branch_for_current_agent(self, ctx)
|
||||
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
|
||||
agent_runs = [
|
||||
sub_agent.run_async(
|
||||
_create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
|
||||
)
|
||||
for sub_agent in self.sub_agents
|
||||
]
|
||||
async for event in _merge_agent_run(agent_runs):
|
||||
yield event
|
||||
|
||||
|
@ -159,7 +159,7 @@ async def test_run_async_with_branch(request: pytest.FixtureRequest):
|
||||
assert len(events) == 1
|
||||
assert events[0].author == agent.name
|
||||
assert events[0].content.parts[0].text == 'Hello, world!'
|
||||
assert events[0].branch.endswith(agent.name)
|
||||
assert events[0].branch == 'parent_branch'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -625,7 +625,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest):
|
||||
assert len(events) == 1
|
||||
assert events[0].author == agent.name
|
||||
assert events[0].content.parts[0].text == 'Hello, live!'
|
||||
assert events[0].branch.endswith(agent.name)
|
||||
assert events[0].branch == 'parent_branch'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -20,6 +20,7 @@ from typing import AsyncGenerator
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.parallel_agent import ParallelAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.events import Event
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
@ -86,7 +87,51 @@ async def test_run_async(request: pytest.FixtureRequest):
|
||||
# and agent1 has a delay.
|
||||
assert events[0].author == agent2.name
|
||||
assert events[1].author == agent1.name
|
||||
assert events[0].branch.endswith(agent2.name)
|
||||
assert events[1].branch.endswith(agent1.name)
|
||||
assert events[0].branch.endswith(f'{parallel_agent.name}.{agent2.name}')
|
||||
assert events[1].branch.endswith(f'{parallel_agent.name}.{agent1.name}')
|
||||
assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!'
|
||||
assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_branches(request: pytest.FixtureRequest):
|
||||
agent1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent_1',
|
||||
delay=0.5,
|
||||
)
|
||||
agent2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
|
||||
agent3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3')
|
||||
sequential_agent = SequentialAgent(
|
||||
name=f'{request.function.__name__}_test_sequential_agent',
|
||||
sub_agents=[agent2, agent3],
|
||||
)
|
||||
parallel_agent = ParallelAgent(
|
||||
name=f'{request.function.__name__}_test_parallel_agent',
|
||||
sub_agents=[
|
||||
sequential_agent,
|
||||
agent1,
|
||||
],
|
||||
)
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, parallel_agent
|
||||
)
|
||||
events = [e async for e in parallel_agent.run_async(parent_ctx)]
|
||||
|
||||
assert len(events) == 3
|
||||
assert (
|
||||
events[0].author == agent2.name
|
||||
and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
|
||||
)
|
||||
assert (
|
||||
events[1].author == agent3.name
|
||||
and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}'
|
||||
)
|
||||
# Descendants of the same sub-agent should have the same branch.
|
||||
assert events[0].branch == events[1].branch
|
||||
assert (
|
||||
events[2].author == agent1.name
|
||||
and events[2].branch == f'{parallel_agent.name}.{agent1.name}'
|
||||
)
|
||||
# Sub-agents should have different branches.
|
||||
assert events[2].branch != events[1].branch
|
||||
assert events[2].branch != events[0].branch
|
||||
|
Loading…
Reference in New Issue
Block a user