diff --git a/contributing/samples/simple_sequential_agent/__init__.py b/contributing/samples/simple_sequential_agent/__init__.py new file mode 100644 index 0000000..c48963c --- /dev/null +++ b/contributing/samples/simple_sequential_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/simple_sequential_agent/agent.py b/contributing/samples/simple_sequential_agent/agent.py new file mode 100644 index 0000000..74e8f58 --- /dev/null +++ b/contributing/samples/simple_sequential_agent/agent.py @@ -0,0 +1,94 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def check_prime(nums: list[int]) -> str:
  """Report which of the given numbers are prime.

  Every entry is coerced with int() before it is tested. Values below 2 are
  never prime. Duplicates are reported once, and the primes are listed in
  ascending order so the returned text is deterministic.

  Args:
    nums: The numbers to check; each entry must be accepted by int().

  Returns:
    "No prime numbers found." when no entry is prime, otherwise the primes
    joined by ", " followed by " are prime numbers.".

  Raises:
    ValueError: If an entry cannot be converted by int().
    TypeError: If an entry has a type int() does not accept.
  """
  # Function-scope import keeps this block self-contained.
  import math

  def _is_prime(value: int) -> bool:
    """Trial division up to the exact integer square root."""
    if value < 2:
      return False
    # math.isqrt is exact for arbitrarily large ints; int(value**0.5) goes
    # through a float and can be off by one (or overflow) for big inputs.
    return all(value % divisor for divisor in range(2, math.isqrt(value) + 1))

  # A set removes duplicates; sorting fixes the order, which a bare set does
  # not guarantee.
  primes = sorted({value for value in map(int, nums) if _is_prime(value)})
  if not primes:
    return "No prime numbers found."
  return f"{', '.join(map(str, primes))} are prime numbers."
@override
async def _run_live_impl(
    self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
  """Reject live execution: ParallelAgent does not support it yet.

  Args:
    ctx: The invocation context of the agent (not used).

  Raises:
    NotImplementedError: Always, as soon as the generator is first advanced.
  """
  unsupported = NotImplementedError(
      "This is not supported yet for ParallelAgent."
  )
  raise unsupported
  yield  # never reached; its presence makes this an async generator function
async def _run_live_impl(
    self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
  """Implementation for live SequentialAgent.

  Compared to the non-live case, live agents process a continuous stream of
  audio or video, so a sub-agent has no way to tell whether it is finished
  and should hand over to the next agent. To work around this, a
  task_completed() tool is added to every LlmAgent sub-agent; the model calls
  it to signal that its task is done so the next agent can take over.

  The tool and the extra instruction are added at most once per sub-agent,
  even when this method runs several times on the same agent tree.

  Args:
    ctx: The invocation context of the agent.

  Yields:
    Every event produced by each sub-agent's run_live(), in sub-agent order.
  """

  def task_completed():
    """
    Signals that the model has successfully completed the user's question
    or task.
    """
    return "Task completion signaled."

  signal_name = task_completed.__name__

  def _exposes_signal(tool) -> bool:
    # Plain callables are identified by __name__. Tool objects presumably
    # expose a `name` attribute instead -- TODO confirm against the tool
    # base class.
    return signal_name in (
        getattr(tool, "__name__", None),
        getattr(tool, "name", None),
    )

  # There is no way to know whether live mode is used during the init phase,
  # so the tool has to be injected here.
  for sub_agent in self.sub_agents:
    if not isinstance(sub_agent, LlmAgent):
      continue
    # Dedupe by tool name. Comparing the name string against the tool objects
    # themselves never matches, which would re-append the tool and the
    # instruction text on every live run.
    if any(_exposes_signal(tool) for tool in sub_agent.tools):
      continue
    sub_agent.tools.append(task_completed)
    # NOTE(review): assumes `instruction` is a plain str; a callable
    # instruction provider would fail on `+=` -- confirm and handle if needed.
    # The leading newline keeps the hint from being glued onto the last word
    # of the existing instruction.
    sub_agent.instruction += (
        "\nIf you finished the user's request according to its description,"
        f" call {signal_name} function to exit so the next agents can take"
        " over. When calling this function, do not generate any text other"
        " than the function call."
    )

  for sub_agent in self.sub_agents:
    async for event in sub_agent.run_live(ctx):
      yield event
send_task.cancel() await llm_connection.close() + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == 'task_completed' + ): + # this is used for sequential agent to signal the end of the agent. + await asyncio.sleep(1) + # cancel the tasks that belongs to the closed connection. + send_task.cancel() + return finally: # Clean up if not send_task.done(): @@ -237,7 +249,7 @@ class BaseLlmFlow(ABC): if ( event.content and event.content.parts - and event.content.parts[0].text + and event.content.parts[0].inline_data is None and not event.partial ): # This can be either user data or transcription data. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1209e03..4de3acc 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -254,13 +254,13 @@ class Runner: """Runs the agent in live mode (experimental feature). Args: - session: The session to use. This parameter is deprecated, please use - `user_id` and `session_id` instead. user_id: The user ID for the session. Required if `session` is None. session_id: The session ID for the session. Required if `session` is None. live_request_queue: The queue for live requests. run_config: The run config for the agent. + session: The session to use. This parameter is deprecated, please use + `user_id` and `session_id` instead. Yields: AsyncGenerator[Event, None]: An asynchronous generator that yields @@ -302,22 +302,24 @@ class Runner: invocation_context.active_streaming_tools = {} # TODO(hangfei): switch to use canonical_tools. - for tool in invocation_context.agent.tools: - # replicate a LiveRequestQueue for streaming tools that relis on - # LiveRequestQueue - from typing import get_type_hints + # for shell agents, there is no tools associated with it so we should skip. 
+ if hasattr(invocation_context.agent, 'tools'): + for tool in invocation_context.agent.tools: + # replicate a LiveRequestQueue for streaming tools that relis on + # LiveRequestQueue + from typing import get_type_hints - type_hints = get_type_hints(tool) - for arg_type in type_hints.values(): - if arg_type is LiveRequestQueue: - if not invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools = {} - active_streaming_tools = ActiveStreamingTool( - stream=LiveRequestQueue() - ) - invocation_context.active_streaming_tools[tool.__name__] = ( - active_streaming_tools - ) + type_hints = get_type_hints(tool) + for arg_type in type_hints.values(): + if arg_type is LiveRequestQueue: + if not invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools = {} + active_streaming_tools = ActiveStreamingTool( + stream=LiveRequestQueue() + ) + invocation_context.active_streaming_tools[tool.__name__] = ( + active_streaming_tools + ) async for event in invocation_context.agent.run_live(invocation_context): self.session_service.append_event(session=session, event=event)