diff --git a/src/google/adk/agents/callback_context.py b/src/google/adk/agents/callback_context.py index 9d6e311..724d49a 100644 --- a/src/google/adk/agents/callback_context.py +++ b/src/google/adk/agents/callback_context.py @@ -60,11 +60,6 @@ class CallbackContext(ReadonlyContext): """ return self._state - @property - def user_content(self) -> Optional[types.Content]: - """The user content that started this invocation. READONLY field.""" - return self._invocation_context.user_content - async def load_artifact( self, filename: str, version: Optional[int] = None ) -> Optional[types.Part]: diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 4d4ceae..0076c6a 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -15,7 +15,15 @@ from __future__ import annotations import logging -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Literal, + Optional, + Union, +) from google.genai import types from pydantic import BaseModel @@ -38,6 +46,7 @@ from ..models.llm_response import LlmResponse from ..models.registry import LLMRegistry from ..planners.base_planner import BasePlanner from ..tools.base_tool import BaseTool +from ..tools.base_toolset import BaseToolset from ..tools.function_tool import FunctionTool from ..tools.tool_context import ToolContext from .base_agent import BaseAgent @@ -89,18 +98,19 @@ AfterToolCallback: TypeAlias = Union[ InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] -ToolUnion: TypeAlias = Union[Callable, BaseTool] +ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset] ExamplesUnion = Union[list[Example], BaseExampleProvider] -def _convert_tool_union_to_tool( - tool_union: ToolUnion, -) -> BaseTool: - return ( - tool_union - if isinstance(tool_union, BaseTool) - else FunctionTool(tool_union) - ) +async def _convert_tool_union_to_tools( + tool_union: ToolUnion, ctx: ReadonlyContext +) -> list[BaseTool]: + if isinstance(tool_union, BaseTool): + return [tool_union] + if isinstance(tool_union, Callable): + return [FunctionTool(func=tool_union)] + + return await tool_union.get_tools(ctx) class LlmAgent(BaseAgent): @@ -312,13 +322,17 @@ class LlmAgent(BaseAgent): else: return self.global_instruction(ctx) - @property - def canonical_tools(self) -> list[BaseTool]: - """The resolved self.tools field as a list of BaseTool. + async def canonical_tools( + self, ctx: ReadonlyContext = None + ) -> list[BaseTool]: + """The resolved self.tools field as a list of BaseTool based on the context. This method is only for use by Agent Development Kit. """ - return [_convert_tool_union_to_tool(tool) for tool in self.tools] + resolved_tools = [] + for tool_union in self.tools: + resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx)) + return resolved_tools @property def canonical_before_model_callbacks( diff --git a/src/google/adk/agents/readonly_context.py b/src/google/adk/agents/readonly_context.py index fb373cc..928e2d1 100644 --- a/src/google/adk/agents/readonly_context.py +++ b/src/google/adk/agents/readonly_context.py @@ -15,10 +15,11 @@ from __future__ import annotations from types import MappingProxyType -from typing import Any +from typing import Any, Optional from typing import TYPE_CHECKING if TYPE_CHECKING: + from google.genai import types from .invocation_context import InvocationContext @@ -30,6 +31,11 @@ class ReadonlyContext: ) -> None: self._invocation_context = invocation_context + @property + def user_content(self) -> Optional[types.Content]: + """The user content that started this invocation. READONLY field.""" + return self._invocation_context.user_content + @property def invocation_id(self) -> str: """The current invocation id.""" diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 9a2f355..8ad30b7 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from typing_extensions import override from ..agents.invocation_context import InvocationContext +from ..agents.readonly_context import ReadonlyContext from ..events.event import Event from ..flows.llm_flows import functions from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor @@ -105,7 +106,12 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): function_response_event = await functions.handle_function_calls_async( invocation_context, event, - {tool.name: tool for tool in agent.canonical_tools}, + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, # there could be parallel function calls that require auth # auth response would be a dict keyed by function call id tools_to_resume, diff --git a/src/google/adk/cli/agent_graph.py b/src/google/adk/cli/agent_graph.py index ebc6191..c1b5fa5 100644 --- a/src/google/adk/cli/agent_graph.py +++ b/src/google/adk/cli/agent_graph.py @@ -35,7 +35,7 @@ else: retrieval_tool_module_loaded = True -def build_graph(graph, agent: BaseAgent, highlight_pairs): +async def build_graph(graph, agent: BaseAgent, highlight_pairs): dark_green = '#0F5223' light_green = '#69CB87' light_gray = '#cccccc' @@ -133,12 +133,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs): build_graph(graph, sub_agent, highlight_pairs) draw_edge(agent.name, sub_agent.name) if isinstance(agent, LlmAgent): - for tool in agent.canonical_tools: + for tool in await agent.canonical_tools(): draw_node(tool) draw_edge(agent.name, get_node_name(tool)) -def get_agent_graph(root_agent, highlights_pairs, image=False): +async def get_agent_graph(root_agent, highlights_pairs, image=False): print('build graph') graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'}) build_graph(graph, root_agent, highlights_pairs) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index aa3de52..91cf3b3 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -24,13 +24,9 @@ import re import sys import traceback import typing -from typing import Any -from typing import List -from typing import Literal -from typing import Optional +from typing import Any, List, Literal, Optional, Union import click -from click import Tuple from fastapi import FastAPI from fastapi import HTTPException from fastapi import Query @@ -53,9 +49,10 @@ from pydantic import ValidationError from starlette.types import Lifespan from ..agents import RunConfig +from ..agents.base_agent import BaseAgent from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequestQueue -from ..agents.llm_agent import Agent +from ..agents.llm_agent import Agent, LlmAgent from ..agents.run_config import StreamingMode from ..artifacts import InMemoryArtifactService from ..events.event import Event @@ -65,6 +62,7 @@ from ..sessions.database_session_service import DatabaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session from ..sessions.vertex_ai_session_service import VertexAiSessionService +from ..tools.base_toolset import BaseToolset from .cli_eval import EVAL_SESSION_ID_PREFIX from .cli_eval import EvalMetric from .cli_eval import EvalMetricResult @@ -163,6 +161,7 @@ def get_fast_api_app( trace.set_tracer_provider(provider) exit_stacks = [] + toolsets_to_close: set[BaseToolset] = set() @asynccontextmanager async def internal_lifespan(app: FastAPI): @@ -173,6 +172,8 @@ def get_fast_api_app( if exit_stacks: for stack in exit_stacks: await stack.aclose() + for toolset in toolsets_to_close: + await toolset.close() else: yield @@ -673,7 +674,7 @@ def get_fast_api_app( from_name = event.author to_name = function_call.name function_call_highlights.append((from_name, to_name)) - dot_graph = agent_graph.get_agent_graph( + dot_graph = await agent_graph.get_agent_graph( root_agent, function_call_highlights ) elif function_responses: @@ -682,13 +683,13 @@ def get_fast_api_app( from_name = function_response.name to_name = event.author function_responses_highlights.append((from_name, to_name)) - dot_graph = agent_graph.get_agent_graph( + dot_graph = await agent_graph.get_agent_graph( root_agent, function_responses_highlights ) else: from_name = event.author to_name = "" - dot_graph = agent_graph.get_agent_graph( + dot_graph = await agent_graph.get_agent_graph( root_agent, [(from_name, to_name)] ) if dot_graph and isinstance(dot_graph, graphviz.Digraph): @@ -766,6 +767,16 @@ def get_fast_api_app( for task in pending: task.cancel() + def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]: + toolsets = set() + if isinstance(agent, LlmAgent): + for tool_union in agent.tools: + if isinstance(tool_union, BaseToolset): + toolsets.add(tool_union) + for sub_agent in agent.sub_agents: + toolsets.update(_get_all_toolsets(sub_agent)) + return toolsets + async def _get_root_agent_async(app_name: str) -> Agent: """Returns the root agent for the given app.""" if app_name in root_agent_dict: @@ -786,6 +797,7 @@ def get_fast_api_app( raise RuntimeError(f"error getting root agent, {e}") from e root_agent_dict[app_name] = root_agent + toolsets_to_close.update(_get_all_toolsets(root_agent)) return root_agent async def _get_runner_async(app_name: str) -> Runner: diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 779c8ad..09fcf26 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -258,7 +258,7 @@ class EvaluationGenerator: if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent): return - for tool in agent.canonical_tools: + for tool in await agent.canonical_tools(): tool_name = tool.name if tool_name in all_mock_tools: agent.before_tool_callback = callback diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index acf4d54..5268034 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -29,6 +29,7 @@ from ...agents.base_agent import BaseAgent from ...agents.callback_context import CallbackContext from ...agents.invocation_context import InvocationContext from ...agents.live_request_queue import LiveRequestQueue +from ...agents.readonly_context import ReadonlyContext from ...agents.run_config import StreamingMode from ...agents.transcription_entry import TranscriptionEntry from ...events.event import Event @@ -296,7 +297,9 @@ class BaseLlmFlow(ABC): yield event # Run processors for tools. - for tool in agent.canonical_tools: + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ): tool_context = ToolContext(invocation_context) await tool.process_llm_request( tool_context=tool_context, llm_request=llm_request diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1ec8631..529c467 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -18,9 +18,7 @@ import asyncio import logging import queue import threading -from typing import AsyncGenerator -from typing import Generator -from typing import Optional +from typing import AsyncGenerator, Generator, Optional from deprecated import deprecated from google.genai import types @@ -391,7 +389,7 @@ class Runner: f'CFC is not supported for model: {model_name} in agent:' f' {self.agent.name}' ) - if built_in_code_execution not in self.agent.canonical_tools: + if built_in_code_execution not in self.agent.canonical_tools(): self.agent.tools.append(built_in_code_execution) return InvocationContext( diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py new file mode 100644 index 0000000..d603a8e --- /dev/null +++ b/src/google/adk/tools/base_toolset.py @@ -0,0 +1,56 @@ +from abc import ABC +from abc import abstractmethod +from typing import Protocol + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.tools.base_tool import BaseTool + + +class ToolPredicate(Protocol): + """Base class for a predicate that defines the interface to decide whether a + + tool should be exposed to LLM. Toolset implementer could consider whether to + accept such instance in the toolset's constructor and apply the predicate in + get_tools method. + """ + + def __call__( + self, tool: BaseTool, readonly_context: ReadonlyContext = None + ) -> bool: + """Decide whether the passed-in tool should be exposed to LLM based on the + + current context. True if the tool is usable by the LLM. + + It's used to filter tools in the toolset. + """ + + +class BaseToolset(ABC): + """Base class for toolset. + + A toolset is a collection of tools that can be used by an agent. + """ + + @abstractmethod + async def get_tools( + self, readony_context: ReadonlyContext = None + ) -> list[BaseTool]: + """Return all tools in the toolset based on the provided context. + + Args: + readony_context (ReadonlyContext, optional): Context used to filter tools + available to the agent. If None, all tools in the toolset are returned. + + Returns: + list[BaseTool]: A list of tools available under the specified context. + """ + + @abstractmethod + async def close(self) -> None: + """Performs cleanup and releases resources held by the toolset. + + NOTE: This method is invoked, for example, at the end of an agent server's + lifecycle or when the toolset is no longer needed. Implementations + should ensure that any open connections, files, or other managed + resources are properly released to prevent leaks. + """