add toolset base class and allow llm agent to accept toolset as tools

PiperOrigin-RevId: 756605470
Authored by Xiang (Sean) Zhou on 2025-05-08 22:27:13 -07:00; committed by Copybara-Service
parent 8963300518
commit 4d7298e4f2
10 changed files with 129 additions and 39 deletions
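
With this change, LlmAgent.tools accepts plain callables, BaseTool instances, and BaseToolset instances interchangeably. A minimal sketch of the new call pattern (the model id and my_search_toolset are illustrative placeholders, not part of this commit):

from google.adk.agents.llm_agent import LlmAgent

# Hypothetical BaseToolset instance; any concrete subclass works here.
my_search_toolset = ...

def get_weather(city: str) -> str:
  """Plain function: wrapped into a FunctionTool automatically."""
  return f'Sunny in {city}'

agent = LlmAgent(
    name='assistant',
    model='gemini-2.0-flash',  # illustrative model id
    # tools may now mix callables, BaseTool instances, and toolsets;
    # toolsets are expanded via get_tools() when the request is built.
    tools=[get_weather, my_search_toolset],
)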

View File

@@ -60,11 +60,6 @@ class CallbackContext(ReadonlyContext):
     """
     return self._state

-  @property
-  def user_content(self) -> Optional[types.Content]:
-    """The user content that started this invocation. READONLY field."""
-    return self._invocation_context.user_content
-
   async def load_artifact(
       self, filename: str, version: Optional[int] = None
   ) -> Optional[types.Part]:

View File

@@ -15,7 +15,15 @@
 from __future__ import annotations

 import logging
-from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Literal,
+    Optional,
+    Union,
+)

 from google.genai import types
 from pydantic import BaseModel
@@ -38,6 +46,7 @@ from ..models.llm_response import LlmResponse
 from ..models.registry import LLMRegistry
 from ..planners.base_planner import BasePlanner
 from ..tools.base_tool import BaseTool
+from ..tools.base_toolset import BaseToolset
 from ..tools.function_tool import FunctionTool
 from ..tools.tool_context import ToolContext
 from .base_agent import BaseAgent
@@ -89,18 +98,19 @@ AfterToolCallback: TypeAlias = Union[
 InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]

-ToolUnion: TypeAlias = Union[Callable, BaseTool]
+ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
 ExamplesUnion = Union[list[Example], BaseExampleProvider]


-def _convert_tool_union_to_tool(
-    tool_union: ToolUnion,
-) -> BaseTool:
-  return (
-      tool_union
-      if isinstance(tool_union, BaseTool)
-      else FunctionTool(tool_union)
-  )
+async def _convert_tool_union_to_tools(
+    tool_union: ToolUnion, ctx: Optional[ReadonlyContext]
+) -> list[BaseTool]:
+  if isinstance(tool_union, BaseTool):
+    return [tool_union]
+  if isinstance(tool_union, Callable):
+    return [FunctionTool(func=tool_union)]
+
+  return await tool_union.get_tools(ctx)


 class LlmAgent(BaseAgent):
@@ -312,13 +322,17 @@ class LlmAgent(BaseAgent):
     else:
       return self.global_instruction(ctx)

-  @property
-  def canonical_tools(self) -> list[BaseTool]:
-    """The resolved self.tools field as a list of BaseTool.
+  async def canonical_tools(
+      self, ctx: Optional[ReadonlyContext] = None
+  ) -> list[BaseTool]:
+    """The resolved self.tools field as a list of BaseTool based on the context.

     This method is only for use by Agent Development Kit.
     """
-    return [_convert_tool_union_to_tool(tool) for tool in self.tools]
+    resolved_tools = []
+    for tool_union in self.tools:
+      resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx))
+    return resolved_tools

   @property
   def canonical_before_model_callbacks(
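
Because canonical_tools changed from a synchronous property to an async method taking an optional context, call sites migrate roughly as below (illustrative fragment, assumed to run inside an async function with agent and invocation_context in scope):

# Before this commit (synchronous property):
#   tools = agent.canonical_tools
# After (async method, optionally scoped to the current invocation):
tools = await agent.canonical_tools(ReadonlyContext(invocation_context))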

View File

@@ -15,10 +15,11 @@
 from __future__ import annotations

 from types import MappingProxyType
-from typing import Any
+from typing import Any, Optional
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
+  from google.genai import types
+
   from .invocation_context import InvocationContext
@@ -30,6 +31,11 @@ class ReadonlyContext:
   ) -> None:
     self._invocation_context = invocation_context

+  @property
+  def user_content(self) -> Optional[types.Content]:
+    """The user content that started this invocation. READONLY field."""
+    return self._invocation_context.user_content
+
   @property
   def invocation_id(self) -> str:
     """The current invocation id."""

View File

@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 from typing_extensions import override

 from ..agents.invocation_context import InvocationContext
+from ..agents.readonly_context import ReadonlyContext
 from ..events.event import Event
 from ..flows.llm_flows import functions
 from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
@@ -105,7 +106,12 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
         function_response_event = await functions.handle_function_calls_async(
             invocation_context,
             event,
-            {tool.name: tool for tool in agent.canonical_tools},
+            {
+                tool.name: tool
+                for tool in await agent.canonical_tools(
+                    ReadonlyContext(invocation_context)
+                )
+            },
             # there could be parallel function calls that require auth
             # auth response would be a dict keyed by function call id
             tools_to_resume,

View File

@@ -35,7 +35,7 @@ else:
   retrieval_tool_module_loaded = True


-def build_graph(graph, agent: BaseAgent, highlight_pairs):
+async def build_graph(graph, agent: BaseAgent, highlight_pairs):
   dark_green = '#0F5223'
   light_green = '#69CB87'
   light_gray = '#cccccc'
@@ -133,12 +133,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
-      build_graph(graph, sub_agent, highlight_pairs)
+      await build_graph(graph, sub_agent, highlight_pairs)
       draw_edge(agent.name, sub_agent.name)
   if isinstance(agent, LlmAgent):
-    for tool in agent.canonical_tools:
+    for tool in await agent.canonical_tools():
       draw_node(tool)
       draw_edge(agent.name, get_node_name(tool))


-def get_agent_graph(root_agent, highlights_pairs, image=False):
+async def get_agent_graph(root_agent, highlights_pairs, image=False):
   print('build graph')
   graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
-  build_graph(graph, root_agent, highlights_pairs)
+  await build_graph(graph, root_agent, highlights_pairs)

View File

@@ -24,13 +24,9 @@ import re
 import sys
 import traceback
 import typing
-from typing import Any
-from typing import List
-from typing import Literal
-from typing import Optional
+from typing import Any, List, Literal, Optional, Union

 import click
-from click import Tuple
 from fastapi import FastAPI
 from fastapi import HTTPException
 from fastapi import Query
@@ -53,9 +49,10 @@ from pydantic import ValidationError
 from starlette.types import Lifespan

 from ..agents import RunConfig
+from ..agents.base_agent import BaseAgent
 from ..agents.live_request_queue import LiveRequest
 from ..agents.live_request_queue import LiveRequestQueue
-from ..agents.llm_agent import Agent
+from ..agents.llm_agent import Agent, LlmAgent
 from ..agents.run_config import StreamingMode
 from ..artifacts import InMemoryArtifactService
 from ..events.event import Event
@@ -65,6 +62,7 @@ from ..sessions.database_session_service import DatabaseSessionService
 from ..sessions.in_memory_session_service import InMemorySessionService
 from ..sessions.session import Session
 from ..sessions.vertex_ai_session_service import VertexAiSessionService
+from ..tools.base_toolset import BaseToolset
 from .cli_eval import EVAL_SESSION_ID_PREFIX
 from .cli_eval import EvalMetric
 from .cli_eval import EvalMetricResult
@@ -163,6 +161,7 @@ def get_fast_api_app(
   trace.set_tracer_provider(provider)

   exit_stacks = []
+  toolsets_to_close: set[BaseToolset] = set()

   @asynccontextmanager
   async def internal_lifespan(app: FastAPI):
@@ -173,6 +172,8 @@
       if exit_stacks:
         for stack in exit_stacks:
           await stack.aclose()
+      for toolset in toolsets_to_close:
+        await toolset.close()
     else:
       yield
@@ -673,7 +674,7 @@ def get_fast_api_app(
           from_name = event.author
           to_name = function_call.name
           function_call_highlights.append((from_name, to_name))
-      dot_graph = agent_graph.get_agent_graph(
+      dot_graph = await agent_graph.get_agent_graph(
           root_agent, function_call_highlights
       )
     elif function_responses:
@@ -682,13 +683,13 @@ def get_fast_api_app(
           from_name = function_response.name
           to_name = event.author
           function_responses_highlights.append((from_name, to_name))
-      dot_graph = agent_graph.get_agent_graph(
+      dot_graph = await agent_graph.get_agent_graph(
          root_agent, function_responses_highlights
       )
     else:
       from_name = event.author
       to_name = ""
-      dot_graph = agent_graph.get_agent_graph(
+      dot_graph = await agent_graph.get_agent_graph(
           root_agent, [(from_name, to_name)]
       )
     if dot_graph and isinstance(dot_graph, graphviz.Digraph):
@@ -766,6 +767,16 @@ def get_fast_api_app(
       for task in pending:
         task.cancel()

+  def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
+    toolsets = set()
+    if isinstance(agent, LlmAgent):
+      for tool_union in agent.tools:
+        if isinstance(tool_union, BaseToolset):
+          toolsets.add(tool_union)
+    for sub_agent in agent.sub_agents:
+      toolsets.update(_get_all_toolsets(sub_agent))
+    return toolsets
+
   async def _get_root_agent_async(app_name: str) -> Agent:
     """Returns the root agent for the given app."""
     if app_name in root_agent_dict:
@@ -786,6 +797,7 @@ def get_fast_api_app(
       raise RuntimeError(f"error getting root agent, {e}") from e
     root_agent_dict[app_name] = root_agent
+    toolsets_to_close.update(_get_all_toolsets(root_agent))

     return root_agent

   async def _get_runner_async(app_name: str) -> Runner:
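
The server now collects every toolset reachable from the root agent and closes each one on shutdown. The same collect-then-close discipline applies when running agents outside the dev server, sketched below (hedged example; collect_toolsets is a hypothetical stand-in for logic like _get_all_toolsets above):

import asyncio

async def main(root_agent, collect_toolsets):
  try:
    ...  # run the agent as usual
  finally:
    # Mirror internal_lifespan: close every discovered toolset exactly once.
    for toolset in collect_toolsets(root_agent):
      await toolset.close()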

View File

@@ -258,7 +258,7 @@ class EvaluationGenerator:
     if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
       return

-    for tool in agent.canonical_tools:
+    for tool in await agent.canonical_tools():
       tool_name = tool.name
       if tool_name in all_mock_tools:
         agent.before_tool_callback = callback

View File

@@ -29,6 +29,7 @@ from ...agents.base_agent import BaseAgent
 from ...agents.callback_context import CallbackContext
 from ...agents.invocation_context import InvocationContext
 from ...agents.live_request_queue import LiveRequestQueue
+from ...agents.readonly_context import ReadonlyContext
 from ...agents.run_config import StreamingMode
 from ...agents.transcription_entry import TranscriptionEntry
 from ...events.event import Event
@@ -296,7 +297,9 @@ class BaseLlmFlow(ABC):
       yield event

     # Run processors for tools.
-    for tool in agent.canonical_tools:
+    for tool in await agent.canonical_tools(
+        ReadonlyContext(invocation_context)
+    ):
       tool_context = ToolContext(invocation_context)
       await tool.process_llm_request(
           tool_context=tool_context, llm_request=llm_request

View File

@@ -18,9 +18,7 @@ import asyncio
 import logging
 import queue
 import threading
-from typing import AsyncGenerator
-from typing import Generator
-from typing import Optional
+from typing import AsyncGenerator, Generator, Optional

 from deprecated import deprecated
 from google.genai import types
@@ -391,7 +389,7 @@ class Runner:
           f'CFC is not supported for model: {model_name} in agent:'
           f' {self.agent.name}'
       )
-    if built_in_code_execution not in self.agent.canonical_tools:
+    if built_in_code_execution not in self.agent.canonical_tools():
       self.agent.tools.append(built_in_code_execution)

     return InvocationContext(

View File

@@ -0,0 +1,56 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Optional, Protocol
+
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.tools.base_tool import BaseTool
+
+
+class ToolPredicate(Protocol):
+  """Base class for a predicate that decides whether a tool should be exposed
+  to the LLM. Toolset implementers may accept such an instance in the
+  toolset's constructor and apply the predicate in the get_tools method.
+  """
+
+  def __call__(
+      self, tool: BaseTool, readonly_context: Optional[ReadonlyContext] = None
+  ) -> bool:
+    """Decides whether the passed-in tool should be exposed to the LLM based
+    on the current context. Returns True if the tool is usable by the LLM.
+
+    It is used to filter tools in the toolset.
+    """
+
+
+class BaseToolset(ABC):
+  """Base class for toolsets.
+
+  A toolset is a collection of tools that can be used by an agent.
+  """
+
+  @abstractmethod
+  async def get_tools(
+      self, readonly_context: Optional[ReadonlyContext] = None
+  ) -> list[BaseTool]:
+    """Returns all tools in the toolset based on the provided context.
+
+    Args:
+      readonly_context (ReadonlyContext, optional): Context used to filter
+        tools available to the agent. If None, all tools in the toolset are
+        returned.
+
+    Returns:
+      list[BaseTool]: A list of tools available under the specified context.
+    """
+
+  @abstractmethod
+  async def close(self) -> None:
+    """Performs cleanup and releases resources held by the toolset.
+
+    NOTE: This method is invoked, for example, at the end of an agent
+    server's lifecycle or when the toolset is no longer needed.
+    Implementations should ensure that any open connections, files, or
+    other managed resources are properly released to prevent leaks.
+    """