diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 18a5de4..bdc10ac 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -246,8 +246,6 @@ class BaseAgent(BaseModel): ) -> InvocationContext: """Creates a new invocation context for this agent.""" invocation_context = parent_context.model_copy(update={'agent': self}) - if parent_context.branch: - invocation_context.branch = f'{parent_context.branch}.{self.name}' return invocation_context @property diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index 61ca41b..427128c 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -26,14 +26,20 @@ from ..events.event import Event from .base_agent import BaseAgent -def _set_branch_for_current_agent( - current_agent: BaseAgent, invocation_context: InvocationContext -): +def _create_branch_ctx_for_sub_agent( + agent: BaseAgent, + sub_agent: BaseAgent, + invocation_context: InvocationContext, +) -> InvocationContext: + """Create isolated branch for every sub-agent.""" + invocation_context = invocation_context.model_copy() + branch_suffix = f"{agent.name}.{sub_agent.name}" invocation_context.branch = ( - f"{invocation_context.branch}.{current_agent.name}" + f"{invocation_context.branch}.{branch_suffix}" if invocation_context.branch - else current_agent.name + else branch_suffix ) + return invocation_context async def _merge_agent_run( @@ -90,8 +96,12 @@ class ParallelAgent(BaseAgent): async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: - _set_branch_for_current_agent(self, ctx) - agent_runs = [agent.run_async(ctx) for agent in self.sub_agents] + agent_runs = [ + sub_agent.run_async( + _create_branch_ctx_for_sub_agent(self, sub_agent, ctx) + ) + for sub_agent in self.sub_agents + ] async for event in _merge_agent_run(agent_runs): yield event diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 624bd28..25aca8f 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -159,7 +159,7 @@ async def test_run_async_with_branch(request: pytest.FixtureRequest): assert len(events) == 1 assert events[0].author == agent.name assert events[0].content.parts[0].text == 'Hello, world!' - assert events[0].branch.endswith(agent.name) + assert events[0].branch == 'parent_branch' @pytest.mark.asyncio @@ -625,7 +625,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest): assert len(events) == 1 assert events[0].author == agent.name assert events[0].content.parts[0].text == 'Hello, live!' - assert events[0].branch.endswith(agent.name) + assert events[0].branch == 'parent_branch' @pytest.mark.asyncio diff --git a/tests/unittests/agents/test_parallel_agent.py b/tests/unittests/agents/test_parallel_agent.py index 8b29987..ccfdae3 100644 --- a/tests/unittests/agents/test_parallel_agent.py +++ b/tests/unittests/agents/test_parallel_agent.py @@ -20,6 +20,7 @@ from typing import AsyncGenerator from google.adk.agents.base_agent import BaseAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent from google.adk.events import Event from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.genai import types @@ -86,7 +87,51 @@ async def test_run_async(request: pytest.FixtureRequest): # and agent1 has a delay. assert events[0].author == agent2.name assert events[1].author == agent1.name - assert events[0].branch.endswith(agent2.name) - assert events[1].branch.endswith(agent1.name) + assert events[0].branch.endswith(f'{parallel_agent.name}.{agent2.name}') + assert events[1].branch.endswith(f'{parallel_agent.name}.{agent1.name}') assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!' assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!' + + +@pytest.mark.asyncio +async def test_run_async_branches(request: pytest.FixtureRequest): + agent1 = _TestingAgent( + name=f'{request.function.__name__}_test_agent_1', + delay=0.5, + ) + agent2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') + agent3 = _TestingAgent(name=f'{request.function.__name__}_test_agent_3') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_sequential_agent', + sub_agents=[agent2, agent3], + ) + parallel_agent = ParallelAgent( + name=f'{request.function.__name__}_test_parallel_agent', + sub_agents=[ + sequential_agent, + agent1, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, parallel_agent + ) + events = [e async for e in parallel_agent.run_async(parent_ctx)] + + assert len(events) == 3 + assert ( + events[0].author == agent2.name + and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}' + ) + assert ( + events[1].author == agent3.name + and events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}' + ) + # Descendants of the same sub-agent should have the same branch. + assert events[0].branch == events[1].branch + assert ( + events[2].author == agent1.name + and events[2].branch == f'{parallel_agent.name}.{agent1.name}' + ) + # Sub-agents should have different branches. + assert events[2].branch != events[1].branch + assert events[2].branch != events[0].branch