add toolset base class and allow llm agent to accept toolset as tools

PiperOrigin-RevId: 756605470
This commit is contained in:
Xiang (Sean) Zhou 2025-05-08 22:27:13 -07:00 committed by Copybara-Service
parent 8963300518
commit 4d7298e4f2
10 changed files with 129 additions and 39 deletions

View File

@ -60,11 +60,6 @@ class CallbackContext(ReadonlyContext):
""" """
return self._state return self._state
@property
def user_content(self) -> Optional[types.Content]:
"""The user content that started this invocation. READONLY field."""
return self._invocation_context.user_content
async def load_artifact( async def load_artifact(
self, filename: str, version: Optional[int] = None self, filename: str, version: Optional[int] = None
) -> Optional[types.Part]: ) -> Optional[types.Part]:

View File

@ -15,7 +15,15 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Literal,
Optional,
Union,
)
from google.genai import types from google.genai import types
from pydantic import BaseModel from pydantic import BaseModel
@ -38,6 +46,7 @@ from ..models.llm_response import LlmResponse
from ..models.registry import LLMRegistry from ..models.registry import LLMRegistry
from ..planners.base_planner import BasePlanner from ..planners.base_planner import BasePlanner
from ..tools.base_tool import BaseTool from ..tools.base_tool import BaseTool
from ..tools.base_toolset import BaseToolset
from ..tools.function_tool import FunctionTool from ..tools.function_tool import FunctionTool
from ..tools.tool_context import ToolContext from ..tools.tool_context import ToolContext
from .base_agent import BaseAgent from .base_agent import BaseAgent
@ -89,18 +98,19 @@ AfterToolCallback: TypeAlias = Union[
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
ToolUnion: TypeAlias = Union[Callable, BaseTool] ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
ExamplesUnion = Union[list[Example], BaseExampleProvider] ExamplesUnion = Union[list[Example], BaseExampleProvider]
def _convert_tool_union_to_tool( async def _convert_tool_union_to_tools(
tool_union: ToolUnion, tool_union: ToolUnion, ctx: ReadonlyContext
) -> BaseTool: ) -> list[BaseTool]:
return ( if isinstance(tool_union, BaseTool):
tool_union return [tool_union]
if isinstance(tool_union, BaseTool) if isinstance(tool_union, Callable):
else FunctionTool(tool_union) return [FunctionTool(func=tool_union)]
)
return await tool_union.get_tools(ctx)
class LlmAgent(BaseAgent): class LlmAgent(BaseAgent):
@ -312,13 +322,17 @@ class LlmAgent(BaseAgent):
else: else:
return self.global_instruction(ctx) return self.global_instruction(ctx)
@property async def canonical_tools(
def canonical_tools(self) -> list[BaseTool]: self, ctx: ReadonlyContext = None
"""The resolved self.tools field as a list of BaseTool. ) -> list[BaseTool]:
"""The resolved self.tools field as a list of BaseTool based on the context.
This method is only for use by Agent Development Kit. This method is only for use by Agent Development Kit.
""" """
return [_convert_tool_union_to_tool(tool) for tool in self.tools] resolved_tools = []
for tool_union in self.tools:
resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx))
return resolved_tools
@property @property
def canonical_before_model_callbacks( def canonical_before_model_callbacks(

View File

@ -15,10 +15,11 @@
from __future__ import annotations from __future__ import annotations
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import Any, Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from google.genai import types
from .invocation_context import InvocationContext from .invocation_context import InvocationContext
@ -30,6 +31,11 @@ class ReadonlyContext:
) -> None: ) -> None:
self._invocation_context = invocation_context self._invocation_context = invocation_context
@property
def user_content(self) -> Optional[types.Content]:
"""The user content that started this invocation. READONLY field."""
return self._invocation_context.user_content
@property @property
def invocation_id(self) -> str: def invocation_id(self) -> str:
"""The current invocation id.""" """The current invocation id."""

View File

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from typing_extensions import override from typing_extensions import override
from ..agents.invocation_context import InvocationContext from ..agents.invocation_context import InvocationContext
from ..agents.readonly_context import ReadonlyContext
from ..events.event import Event from ..events.event import Event
from ..flows.llm_flows import functions from ..flows.llm_flows import functions
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
@ -105,7 +106,12 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
function_response_event = await functions.handle_function_calls_async( function_response_event = await functions.handle_function_calls_async(
invocation_context, invocation_context,
event, event,
{tool.name: tool for tool in agent.canonical_tools}, {
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
# there could be parallel function calls that require auth # there could be parallel function calls that require auth
# auth response would be a dict keyed by function call id # auth response would be a dict keyed by function call id
tools_to_resume, tools_to_resume,

View File

@ -35,7 +35,7 @@ else:
retrieval_tool_module_loaded = True retrieval_tool_module_loaded = True
def build_graph(graph, agent: BaseAgent, highlight_pairs): async def build_graph(graph, agent: BaseAgent, highlight_pairs):
dark_green = '#0F5223' dark_green = '#0F5223'
light_green = '#69CB87' light_green = '#69CB87'
light_gray = '#cccccc' light_gray = '#cccccc'
@ -133,12 +133,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
build_graph(graph, sub_agent, highlight_pairs) build_graph(graph, sub_agent, highlight_pairs)
draw_edge(agent.name, sub_agent.name) draw_edge(agent.name, sub_agent.name)
if isinstance(agent, LlmAgent): if isinstance(agent, LlmAgent):
for tool in agent.canonical_tools: for tool in await agent.canonical_tools():
draw_node(tool) draw_node(tool)
draw_edge(agent.name, get_node_name(tool)) draw_edge(agent.name, get_node_name(tool))
def get_agent_graph(root_agent, highlights_pairs, image=False): async def get_agent_graph(root_agent, highlights_pairs, image=False):
print('build graph') print('build graph')
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'}) graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
build_graph(graph, root_agent, highlights_pairs) build_graph(graph, root_agent, highlights_pairs)

View File

@ -24,13 +24,9 @@ import re
import sys import sys
import traceback import traceback
import typing import typing
from typing import Any from typing import Any, List, Literal, Optional, Union
from typing import List
from typing import Literal
from typing import Optional
import click import click
from click import Tuple
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import HTTPException from fastapi import HTTPException
from fastapi import Query from fastapi import Query
@ -53,9 +49,10 @@ from pydantic import ValidationError
from starlette.types import Lifespan from starlette.types import Lifespan
from ..agents import RunConfig from ..agents import RunConfig
from ..agents.base_agent import BaseAgent
from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent from ..agents.llm_agent import Agent, LlmAgent
from ..agents.run_config import StreamingMode from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService from ..artifacts import InMemoryArtifactService
from ..events.event import Event from ..events.event import Event
@ -65,6 +62,7 @@ from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService from ..sessions.vertex_ai_session_service import VertexAiSessionService
from ..tools.base_toolset import BaseToolset
from .cli_eval import EVAL_SESSION_ID_PREFIX from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalMetric from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult from .cli_eval import EvalMetricResult
@ -163,6 +161,7 @@ def get_fast_api_app(
trace.set_tracer_provider(provider) trace.set_tracer_provider(provider)
exit_stacks = [] exit_stacks = []
toolsets_to_close: set[BaseToolset] = set()
@asynccontextmanager @asynccontextmanager
async def internal_lifespan(app: FastAPI): async def internal_lifespan(app: FastAPI):
@ -173,6 +172,8 @@ def get_fast_api_app(
if exit_stacks: if exit_stacks:
for stack in exit_stacks: for stack in exit_stacks:
await stack.aclose() await stack.aclose()
for toolset in toolsets_to_close:
await toolset.close()
else: else:
yield yield
@ -673,7 +674,7 @@ def get_fast_api_app(
from_name = event.author from_name = event.author
to_name = function_call.name to_name = function_call.name
function_call_highlights.append((from_name, to_name)) function_call_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph( dot_graph = await agent_graph.get_agent_graph(
root_agent, function_call_highlights root_agent, function_call_highlights
) )
elif function_responses: elif function_responses:
@ -682,13 +683,13 @@ def get_fast_api_app(
from_name = function_response.name from_name = function_response.name
to_name = event.author to_name = event.author
function_responses_highlights.append((from_name, to_name)) function_responses_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph( dot_graph = await agent_graph.get_agent_graph(
root_agent, function_responses_highlights root_agent, function_responses_highlights
) )
else: else:
from_name = event.author from_name = event.author
to_name = "" to_name = ""
dot_graph = agent_graph.get_agent_graph( dot_graph = await agent_graph.get_agent_graph(
root_agent, [(from_name, to_name)] root_agent, [(from_name, to_name)]
) )
if dot_graph and isinstance(dot_graph, graphviz.Digraph): if dot_graph and isinstance(dot_graph, graphviz.Digraph):
@ -766,6 +767,16 @@ def get_fast_api_app(
for task in pending: for task in pending:
task.cancel() task.cancel()
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
toolsets = set()
if isinstance(agent, LlmAgent):
for tool_union in agent.tools:
if isinstance(tool_union, BaseToolset):
toolsets.add(tool_union)
for sub_agent in agent.sub_agents:
toolsets.update(_get_all_toolsets(sub_agent))
return toolsets
async def _get_root_agent_async(app_name: str) -> Agent: async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app.""" """Returns the root agent for the given app."""
if app_name in root_agent_dict: if app_name in root_agent_dict:
@ -786,6 +797,7 @@ def get_fast_api_app(
raise RuntimeError(f"error getting root agent, {e}") from e raise RuntimeError(f"error getting root agent, {e}") from e
root_agent_dict[app_name] = root_agent root_agent_dict[app_name] = root_agent
toolsets_to_close.update(_get_all_toolsets(root_agent))
return root_agent return root_agent
async def _get_runner_async(app_name: str) -> Runner: async def _get_runner_async(app_name: str) -> Runner:

View File

@ -258,7 +258,7 @@ class EvaluationGenerator:
if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent): if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
return return
for tool in agent.canonical_tools: for tool in await agent.canonical_tools():
tool_name = tool.name tool_name = tool.name
if tool_name in all_mock_tools: if tool_name in all_mock_tools:
agent.before_tool_callback = callback agent.before_tool_callback = callback

View File

@ -29,6 +29,7 @@ from ...agents.base_agent import BaseAgent
from ...agents.callback_context import CallbackContext from ...agents.callback_context import CallbackContext
from ...agents.invocation_context import InvocationContext from ...agents.invocation_context import InvocationContext
from ...agents.live_request_queue import LiveRequestQueue from ...agents.live_request_queue import LiveRequestQueue
from ...agents.readonly_context import ReadonlyContext
from ...agents.run_config import StreamingMode from ...agents.run_config import StreamingMode
from ...agents.transcription_entry import TranscriptionEntry from ...agents.transcription_entry import TranscriptionEntry
from ...events.event import Event from ...events.event import Event
@ -296,7 +297,9 @@ class BaseLlmFlow(ABC):
yield event yield event
# Run processors for tools. # Run processors for tools.
for tool in agent.canonical_tools: for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
):
tool_context = ToolContext(invocation_context) tool_context = ToolContext(invocation_context)
await tool.process_llm_request( await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request

View File

@ -18,9 +18,7 @@ import asyncio
import logging import logging
import queue import queue
import threading import threading
from typing import AsyncGenerator from typing import AsyncGenerator, Generator, Optional
from typing import Generator
from typing import Optional
from deprecated import deprecated from deprecated import deprecated
from google.genai import types from google.genai import types
@ -391,7 +389,7 @@ class Runner:
f'CFC is not supported for model: {model_name} in agent:' f'CFC is not supported for model: {model_name} in agent:'
f' {self.agent.name}' f' {self.agent.name}'
) )
if built_in_code_execution not in self.agent.canonical_tools: if built_in_code_execution not in self.agent.canonical_tools():
self.agent.tools.append(built_in_code_execution) self.agent.tools.append(built_in_code_execution)
return InvocationContext( return InvocationContext(

View File

@ -0,0 +1,56 @@
from abc import ABC
from abc import abstractmethod
from typing import Protocol
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.base_tool import BaseTool
class ToolPredicate(Protocol):
"""Base class for a predicate that defines the interface to decide whether a
tool should be exposed to LLM. Toolset implementer could consider whether to
accept such instance in the toolset's constructor and apply the predicate in
get_tools method.
"""
def __call__(
self, tool: BaseTool, readonly_context: ReadonlyContext = None
) -> bool:
"""Decide whether the passed-in tool should be exposed to LLM based on the
current context. True if the tool is usable by the LLM.
It's used to filter tools in the toolset.
"""
class BaseToolset(ABC):
"""Base class for toolset.
A toolset is a collection of tools that can be used by an agent.
"""
@abstractmethod
async def get_tools(
self, readony_context: ReadonlyContext = None
) -> list[BaseTool]:
"""Return all tools in the toolset based on the provided context.
Args:
readony_context (ReadonlyContext, optional): Context used to filter tools
available to the agent. If None, all tools in the toolset are returned.
Returns:
list[BaseTool]: A list of tools available under the specified context.
"""
@abstractmethod
async def close(self) -> None:
"""Performs cleanup and releases resources held by the toolset.
NOTE: This method is invoked, for example, at the end of an agent server's
lifecycle or when the toolset is no longer needed. Implementations
should ensure that any open connections, files, or other managed
resources are properly released to prevent leaks.
"""