# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import queue
import threading
from typing import AsyncGenerator, Generator, Optional
from deprecated import deprecated
from google.genai import types
from .agents.active_streaming_tool import ActiveStreamingTool
from .agents.base_agent import BaseAgent
from .agents.invocation_context import InvocationContext
from .agents.invocation_context import new_invocation_context_id
from .agents.live_request_queue import LiveRequestQueue
from .agents.llm_agent import LlmAgent
from .agents.run_config import RunConfig
from .agents.run_config import StreamingMode
from .artifacts.base_artifact_service import BaseArtifactService
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
from .events.event import Event
from .memory.base_memory_service import BaseMemoryService
from .memory.in_memory_memory_service import InMemoryMemoryService
from .sessions.base_session_service import BaseSessionService
from .sessions.in_memory_session_service import InMemorySessionService
from .sessions.session import Session
from .telemetry import tracer
from .tools.built_in_code_execution_tool import built_in_code_execution
logger = logging.getLogger(__name__)
class Runner:
"""The Runner class is used to run agents.
It manages the execution of an agent within a session, handling message
processing, event generation, and interaction with various services like
artifact storage, session management, and memory.
Attributes:
app_name: The application name of the runner.
agent: The root agent to run.
artifact_service: The artifact service for the runner.
session_service: The session service for the runner.
memory_service: The memory service for the runner.
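
  Example (a minimal sketch; `my_agent` is an illustrative, caller-defined
  agent and is not part of this module):

      runner = Runner(
          app_name='my_app',
          agent=my_agent,
          session_service=InMemorySessionService(),
      )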
"""
app_name: str
"""The app name of the runner."""
agent: BaseAgent
"""The root agent to run."""
artifact_service: Optional[BaseArtifactService] = None
"""The artifact service for the runner."""
session_service: BaseSessionService
"""The session service for the runner."""
memory_service: Optional[BaseMemoryService] = None
"""The memory service for the runner."""
def __init__(
self,
*,
app_name: str,
agent: BaseAgent,
artifact_service: Optional[BaseArtifactService] = None,
session_service: BaseSessionService,
memory_service: Optional[BaseMemoryService] = None,
):
"""Initializes the Runner.
Args:
app_name: The application name of the runner.
agent: The root agent to run.
artifact_service: The artifact service for the runner.
session_service: The session service for the runner.
memory_service: The memory service for the runner.
"""
self.app_name = app_name
self.agent = agent
self.artifact_service = artifact_service
self.session_service = session_service
self.memory_service = memory_service
def run(
self,
*,
user_id: str,
session_id: str,
new_message: types.Content,
run_config: RunConfig = RunConfig(),
) -> Generator[Event, None, None]:
"""Runs the agent.
    NOTE: This sync interface is only for local testing and convenience purposes.
Consider using `run_async` for production usage.
Args:
user_id: The user ID of the session.
session_id: The session ID of the session.
new_message: A new message to append to the session.
run_config: The run config for the agent.
Yields:
The events generated by the agent.
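
    Example (a sketch; assumes the target session already exists in the
    session service and `content` is a `types.Content` from the user):

        for event in runner.run(
            user_id='user_1', session_id='session_1', new_message=content
        ):
            print(event.author, event.content)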
"""
event_queue = queue.Queue()
async def _invoke_run_async():
try:
async for event in self.run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
run_config=run_config,
):
event_queue.put(event)
finally:
event_queue.put(None)
def _asyncio_thread_main():
try:
asyncio.run(_invoke_run_async())
finally:
event_queue.put(None)
thread = threading.Thread(target=_asyncio_thread_main)
thread.start()
    # Consume and re-yield events from the background thread.
while True:
event = event_queue.get()
if event is None:
break
else:
yield event
thread.join()
async def run_async(
self,
*,
user_id: str,
session_id: str,
new_message: types.Content,
run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]:
"""Main entry method to run the agent in this runner.
Args:
user_id: The user ID of the session.
session_id: The session ID of the session.
new_message: A new message to append to the session.
run_config: The run config for the agent.
Yields:
The events generated by the agent.
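
    Example (a sketch, to be run from an async context; assumes the session
    already exists and `content` is a `types.Content` from the user):

        async for event in runner.run_async(
            user_id='user_1', session_id='session_1', new_message=content
        ):
            print(event.author, event.content)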
"""
with tracer.start_as_current_span('invocation'):
session = self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(f'Session not found: {session_id}')
invocation_context = self._new_invocation_context(
session,
new_message=new_message,
run_config=run_config,
)
root_agent = self.agent
if new_message:
await self._append_new_message_to_session(
session,
new_message,
invocation_context,
run_config.save_input_blobs_as_artifacts,
)
invocation_context.agent = self._find_agent_to_run(session, root_agent)
async for event in invocation_context.agent.run_async(invocation_context):
if not event.partial:
self.session_service.append_event(session=session, event=event)
yield event
async def _append_new_message_to_session(
self,
session: Session,
new_message: types.Content,
invocation_context: InvocationContext,
save_input_blobs_as_artifacts: bool = False,
):
"""Appends a new message to the session.
Args:
session: The session to append the message to.
new_message: The new message to append.
invocation_context: The invocation context for the message.
save_input_blobs_as_artifacts: Whether to save input blobs as artifacts.
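
    Example (a sketch; with an artifact service configured and
    save_input_blobs_as_artifacts enabled, an inline-data part like the one
    below is saved as an artifact and replaced with a text placeholder;
    `png_bytes` is illustrative):

        new_message = types.Content(
            role='user',
            parts=[types.Part(
                inline_data=types.Blob(mime_type='image/png', data=png_bytes)
            )],
        )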
"""
if not new_message.parts:
raise ValueError('No parts in the new_message.')
if self.artifact_service and save_input_blobs_as_artifacts:
# The runner directly saves the artifacts (if applicable) in the
# user message and replaces the artifact data with a file name
# placeholder.
for i, part in enumerate(new_message.parts):
if part.inline_data is None:
continue
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
await self.artifact_service.save_artifact(
app_name=self.app_name,
user_id=session.user_id,
session_id=session.id,
filename=file_name,
artifact=part,
)
new_message.parts[i] = types.Part(
text=f'Uploaded file: {file_name}. It is saved into artifacts'
)
# Appends only. We do not yield the event because it's not from the model.
event = Event(
invocation_id=invocation_context.invocation_id,
author='user',
content=new_message,
)
self.session_service.append_event(session=session, event=event)
async def run_live(
self,
*,
session: Session,
live_request_queue: LiveRequestQueue,
run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]:
"""Runs the agent in live mode (experimental feature).
Args:
session: The session to use.
live_request_queue: The queue for live requests.
run_config: The run config for the agent.
Yields:
The events generated by the agent.
.. warning::
This feature is **experimental** and its API or behavior may change
in future releases.
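
    Example (a sketch; experimental surface, so details may change; the
    session is assumed to already exist):

        queue = LiveRequestQueue()
        async for event in runner.run_live(
            session=session, live_request_queue=queue
        ):
            print(event)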
"""
# TODO: right now, only works for a single audio agent without FC.
invocation_context = self._new_invocation_context_for_live(
session,
live_request_queue=live_request_queue,
run_config=run_config,
)
root_agent = self.agent
invocation_context.agent = self._find_agent_to_run(session, root_agent)
invocation_context.active_streaming_tools = {}
# TODO(hangfei): switch to use canonical_tools.
for tool in invocation_context.agent.tools:
      # Replicate a LiveRequestQueue for streaming tools that rely on
      # LiveRequestQueue.
from typing import get_type_hints
type_hints = get_type_hints(tool)
for arg_type in type_hints.values():
if arg_type is LiveRequestQueue:
if not invocation_context.active_streaming_tools:
invocation_context.active_streaming_tools = {}
active_streaming_tools = ActiveStreamingTool(
stream=LiveRequestQueue()
)
invocation_context.active_streaming_tools[tool.__name__] = (
active_streaming_tools
)
async for event in invocation_context.agent.run_live(invocation_context):
self.session_service.append_event(session=session, event=event)
yield event
async def close_session(self, session: Session):
"""Closes a session and adds it to the memory service (experimental feature).
Args:
session: The session to close.
"""
if self.memory_service:
await self.memory_service.add_session_to_memory(session)
self.session_service.close_session(session=session)
def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent
) -> BaseAgent:
"""Finds the agent to run to continue the session.
A qualified agent must be either of:
- The root agent;
    - An LlmAgent that replied last and is capable of transferring to any
      other agent in the agent hierarchy.
Args:
session: The session to find the agent for.
root_agent: The root agent of the runner.
Returns:
The agent of the last message in the session or the root agent.
"""
for event in filter(lambda e: e.author != 'user', reversed(session.events)):
if event.author == root_agent.name:
# Found root agent.
return root_agent
if not (agent := root_agent.find_sub_agent(event.author)):
# Agent not found, continue looking.
logger.warning(
'Event from an unknown agent: %s, event id: %s',
event.author,
event.id,
)
continue
if self._is_transferable_across_agent_tree(agent):
return agent
# Falls back to root agent if no suitable agents are found in the session.
return root_agent
def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
"""Whether the agent to run can transfer to any other agent in the agent tree.
    This typically means every agent from agent_to_run up to the root agent
    allows transferring to its parent agent.
Args:
agent_to_run: The agent to check for transferability.
Returns:
True if the agent can transfer, False otherwise.
"""
agent = agent_to_run
while agent:
if not isinstance(agent, LlmAgent):
        # Only LLM-based agents can provide the agent transfer capability.
return False
if agent.disallow_transfer_to_parent:
return False
agent = agent.parent_agent
return True
def _new_invocation_context(
self,
session: Session,
*,
new_message: Optional[types.Content] = None,
live_request_queue: Optional[LiveRequestQueue] = None,
run_config: RunConfig = RunConfig(),
) -> InvocationContext:
"""Creates a new invocation context.
Args:
session: The session for the context.
new_message: The new message for the context.
live_request_queue: The live request queue for the context.
run_config: The run config for the context.
Returns:
The new invocation context.
"""
invocation_id = new_invocation_context_id()
if run_config.support_cfc and isinstance(self.agent, LlmAgent):
model_name = self.agent.canonical_model.model
if not model_name.startswith('gemini-2'):
raise ValueError(
f'CFC is not supported for model: {model_name} in agent:'
f' {self.agent.name}'
)
if built_in_code_execution not in self.agent.canonical_tools():
self.agent.tools.append(built_in_code_execution)
return InvocationContext(
artifact_service=self.artifact_service,
session_service=self.session_service,
memory_service=self.memory_service,
invocation_id=invocation_id,
agent=self.agent,
session=session,
user_content=new_message,
live_request_queue=live_request_queue,
run_config=run_config,
)
def _new_invocation_context_for_live(
self,
session: Session,
*,
live_request_queue: Optional[LiveRequestQueue] = None,
run_config: RunConfig = RunConfig(),
) -> InvocationContext:
"""Creates a new invocation context for live multi-agent."""
    # For live multi-agent, we need the model's text transcription as context
    # for the next agent.
if self.agent.sub_agents and live_request_queue:
if not run_config.response_modalities:
        # Default to audio output.
run_config.response_modalities = ['AUDIO']
if not run_config.output_audio_transcription:
run_config.output_audio_transcription = (
types.AudioTranscriptionConfig()
)
elif 'TEXT' not in run_config.response_modalities:
if not run_config.output_audio_transcription:
run_config.output_audio_transcription = (
types.AudioTranscriptionConfig()
)
return self._new_invocation_context(
session,
live_request_queue=live_request_queue,
run_config=run_config,
)
class InMemoryRunner(Runner):
"""An in-memory Runner for testing and development.
This runner uses in-memory implementations for artifact, session, and memory
services, providing a lightweight and self-contained environment for agent
execution.
Attributes:
agent: The root agent to run.
app_name: The application name of the runner. Defaults to
'InMemoryRunner'.
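
  Example (a sketch; `my_agent` is an illustrative, caller-defined LlmAgent;
  session service calls are synchronous at this revision):

      runner = InMemoryRunner(my_agent, app_name='my_app')
      runner.session_service.create_session(
          app_name='my_app', user_id='user_1', session_id='session_1'
      )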
"""
def __init__(self, agent: LlmAgent, *, app_name: str = 'InMemoryRunner'):
"""Initializes the InMemoryRunner.
Args:
agent: The root agent to run.
app_name: The application name of the runner. Defaults to
'InMemoryRunner'.
"""
super().__init__(
app_name=app_name,
agent=agent,
artifact_service=InMemoryArtifactService(),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)