mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
add toolset base class and allow llm agent to accept toolset as tools
PiperOrigin-RevId: 756605470
This commit is contained in:
parent
8963300518
commit
4d7298e4f2
@ -60,11 +60,6 @@ class CallbackContext(ReadonlyContext):
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def user_content(self) -> Optional[types.Content]:
|
||||
"""The user content that started this invocation. READONLY field."""
|
||||
return self._invocation_context.user_content
|
||||
|
||||
async def load_artifact(
|
||||
self, filename: str, version: Optional[int] = None
|
||||
) -> Optional[types.Part]:
|
||||
|
@ -15,7 +15,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
@ -38,6 +46,7 @@ from ..models.llm_response import LlmResponse
|
||||
from ..models.registry import LLMRegistry
|
||||
from ..planners.base_planner import BasePlanner
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.base_toolset import BaseToolset
|
||||
from ..tools.function_tool import FunctionTool
|
||||
from ..tools.tool_context import ToolContext
|
||||
from .base_agent import BaseAgent
|
||||
@ -89,18 +98,19 @@ AfterToolCallback: TypeAlias = Union[
|
||||
|
||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||
|
||||
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
||||
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
|
||||
ExamplesUnion = Union[list[Example], BaseExampleProvider]
|
||||
|
||||
|
||||
def _convert_tool_union_to_tool(
|
||||
tool_union: ToolUnion,
|
||||
) -> BaseTool:
|
||||
return (
|
||||
tool_union
|
||||
if isinstance(tool_union, BaseTool)
|
||||
else FunctionTool(tool_union)
|
||||
)
|
||||
async def _convert_tool_union_to_tools(
|
||||
tool_union: ToolUnion, ctx: ReadonlyContext
|
||||
) -> list[BaseTool]:
|
||||
if isinstance(tool_union, BaseTool):
|
||||
return [tool_union]
|
||||
if isinstance(tool_union, Callable):
|
||||
return [FunctionTool(func=tool_union)]
|
||||
|
||||
return await tool_union.get_tools(ctx)
|
||||
|
||||
|
||||
class LlmAgent(BaseAgent):
|
||||
@ -312,13 +322,17 @@ class LlmAgent(BaseAgent):
|
||||
else:
|
||||
return self.global_instruction(ctx)
|
||||
|
||||
@property
|
||||
def canonical_tools(self) -> list[BaseTool]:
|
||||
"""The resolved self.tools field as a list of BaseTool.
|
||||
async def canonical_tools(
|
||||
self, ctx: ReadonlyContext = None
|
||||
) -> list[BaseTool]:
|
||||
"""The resolved self.tools field as a list of BaseTool based on the context.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
|
||||
resolved_tools = []
|
||||
for tool_union in self.tools:
|
||||
resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx))
|
||||
return resolved_tools
|
||||
|
||||
@property
|
||||
def canonical_before_model_callbacks(
|
||||
|
@ -15,10 +15,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.genai import types
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
@ -30,6 +31,11 @@ class ReadonlyContext:
|
||||
) -> None:
|
||||
self._invocation_context = invocation_context
|
||||
|
||||
@property
|
||||
def user_content(self) -> Optional[types.Content]:
|
||||
"""The user content that started this invocation. READONLY field."""
|
||||
return self._invocation_context.user_content
|
||||
|
||||
@property
|
||||
def invocation_id(self) -> str:
|
||||
"""The current invocation id."""
|
||||
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..agents.readonly_context import ReadonlyContext
|
||||
from ..events.event import Event
|
||||
from ..flows.llm_flows import functions
|
||||
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
|
||||
@ -105,7 +106,12 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
function_response_event = await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{tool.name: tool for tool in agent.canonical_tools},
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
# there could be parallel function calls that require auth
|
||||
# auth response would be a dict keyed by function call id
|
||||
tools_to_resume,
|
||||
|
@ -35,7 +35,7 @@ else:
|
||||
retrieval_tool_module_loaded = True
|
||||
|
||||
|
||||
def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
async def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
dark_green = '#0F5223'
|
||||
light_green = '#69CB87'
|
||||
light_gray = '#cccccc'
|
||||
@ -133,12 +133,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
build_graph(graph, sub_agent, highlight_pairs)
|
||||
draw_edge(agent.name, sub_agent.name)
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool in agent.canonical_tools:
|
||||
for tool in await agent.canonical_tools():
|
||||
draw_node(tool)
|
||||
draw_edge(agent.name, get_node_name(tool))
|
||||
|
||||
|
||||
def get_agent_graph(root_agent, highlights_pairs, image=False):
|
||||
async def get_agent_graph(root_agent, highlights_pairs, image=False):
|
||||
print('build graph')
|
||||
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
|
||||
build_graph(graph, root_agent, highlights_pairs)
|
||||
|
@ -24,13 +24,9 @@ import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
import click
|
||||
from click import Tuple
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
@ -53,9 +49,10 @@ from pydantic import ValidationError
|
||||
from starlette.types import Lifespan
|
||||
|
||||
from ..agents import RunConfig
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.llm_agent import Agent, LlmAgent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..events.event import Event
|
||||
@ -65,6 +62,7 @@ from ..sessions.database_session_service import DatabaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
||||
from ..tools.base_toolset import BaseToolset
|
||||
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalMetricResult
|
||||
@ -163,6 +161,7 @@ def get_fast_api_app(
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
exit_stacks = []
|
||||
toolsets_to_close: set[BaseToolset] = set()
|
||||
|
||||
@asynccontextmanager
|
||||
async def internal_lifespan(app: FastAPI):
|
||||
@ -173,6 +172,8 @@ def get_fast_api_app(
|
||||
if exit_stacks:
|
||||
for stack in exit_stacks:
|
||||
await stack.aclose()
|
||||
for toolset in toolsets_to_close:
|
||||
await toolset.close()
|
||||
else:
|
||||
yield
|
||||
|
||||
@ -673,7 +674,7 @@ def get_fast_api_app(
|
||||
from_name = event.author
|
||||
to_name = function_call.name
|
||||
function_call_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
dot_graph = await agent_graph.get_agent_graph(
|
||||
root_agent, function_call_highlights
|
||||
)
|
||||
elif function_responses:
|
||||
@ -682,13 +683,13 @@ def get_fast_api_app(
|
||||
from_name = function_response.name
|
||||
to_name = event.author
|
||||
function_responses_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
dot_graph = await agent_graph.get_agent_graph(
|
||||
root_agent, function_responses_highlights
|
||||
)
|
||||
else:
|
||||
from_name = event.author
|
||||
to_name = ""
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
dot_graph = await agent_graph.get_agent_graph(
|
||||
root_agent, [(from_name, to_name)]
|
||||
)
|
||||
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
|
||||
@ -766,6 +767,16 @@ def get_fast_api_app(
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
|
||||
toolsets = set()
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool_union in agent.tools:
|
||||
if isinstance(tool_union, BaseToolset):
|
||||
toolsets.add(tool_union)
|
||||
for sub_agent in agent.sub_agents:
|
||||
toolsets.update(_get_all_toolsets(sub_agent))
|
||||
return toolsets
|
||||
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
@ -786,6 +797,7 @@ def get_fast_api_app(
|
||||
raise RuntimeError(f"error getting root agent, {e}") from e
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
toolsets_to_close.update(_get_all_toolsets(root_agent))
|
||||
return root_agent
|
||||
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
|
@ -258,7 +258,7 @@ class EvaluationGenerator:
|
||||
if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
for tool in agent.canonical_tools:
|
||||
for tool in await agent.canonical_tools():
|
||||
tool_name = tool.name
|
||||
if tool_name in all_mock_tools:
|
||||
agent.before_tool_callback = callback
|
||||
|
@ -29,6 +29,7 @@ from ...agents.base_agent import BaseAgent
|
||||
from ...agents.callback_context import CallbackContext
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...agents.live_request_queue import LiveRequestQueue
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...agents.run_config import StreamingMode
|
||||
from ...agents.transcription_entry import TranscriptionEntry
|
||||
from ...events.event import Event
|
||||
@ -296,7 +297,9 @@ class BaseLlmFlow(ABC):
|
||||
yield event
|
||||
|
||||
# Run processors for tools.
|
||||
for tool in agent.canonical_tools:
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
):
|
||||
tool_context = ToolContext(invocation_context)
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
|
@ -18,9 +18,7 @@ import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, Generator, Optional
|
||||
|
||||
from deprecated import deprecated
|
||||
from google.genai import types
|
||||
@ -391,7 +389,7 @@ class Runner:
|
||||
f'CFC is not supported for model: {model_name} in agent:'
|
||||
f' {self.agent.name}'
|
||||
)
|
||||
if built_in_code_execution not in self.agent.canonical_tools:
|
||||
if built_in_code_execution not in self.agent.canonical_tools():
|
||||
self.agent.tools.append(built_in_code_execution)
|
||||
|
||||
return InvocationContext(
|
||||
|
56
src/google/adk/tools/base_toolset.py
Normal file
56
src/google/adk/tools/base_toolset.py
Normal file
@ -0,0 +1,56 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class ToolPredicate(Protocol):
|
||||
"""Base class for a predicate that defines the interface to decide whether a
|
||||
|
||||
tool should be exposed to LLM. Toolset implementer could consider whether to
|
||||
accept such instance in the toolset's constructor and apply the predicate in
|
||||
get_tools method.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, tool: BaseTool, readonly_context: ReadonlyContext = None
|
||||
) -> bool:
|
||||
"""Decide whether the passed-in tool should be exposed to LLM based on the
|
||||
|
||||
current context. True if the tool is usable by the LLM.
|
||||
|
||||
It's used to filter tools in the toolset.
|
||||
"""
|
||||
|
||||
|
||||
class BaseToolset(ABC):
|
||||
"""Base class for toolset.
|
||||
|
||||
A toolset is a collection of tools that can be used by an agent.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_tools(
|
||||
self, readony_context: ReadonlyContext = None
|
||||
) -> list[BaseTool]:
|
||||
"""Return all tools in the toolset based on the provided context.
|
||||
|
||||
Args:
|
||||
readony_context (ReadonlyContext, optional): Context used to filter tools
|
||||
available to the agent. If None, all tools in the toolset are returned.
|
||||
|
||||
Returns:
|
||||
list[BaseTool]: A list of tools available under the specified context.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Performs cleanup and releases resources held by the toolset.
|
||||
|
||||
NOTE: This method is invoked, for example, at the end of an agent server's
|
||||
lifecycle or when the toolset is no longer needed. Implementations
|
||||
should ensure that any open connections, files, or other managed
|
||||
resources are properly released to prevent leaks.
|
||||
"""
|
Loading…
Reference in New Issue
Block a user