add toolset base class and allow llm agent to accept toolset as tools

PiperOrigin-RevId: 756605470
This commit is contained in:
Xiang (Sean) Zhou
2025-05-08 22:27:13 -07:00
committed by Copybara-Service
parent 8963300518
commit 4d7298e4f2
10 changed files with 129 additions and 39 deletions
+3 -3
View File
@@ -35,7 +35,7 @@ else:
retrieval_tool_module_loaded = True
def build_graph(graph, agent: BaseAgent, highlight_pairs):
async def build_graph(graph, agent: BaseAgent, highlight_pairs):
dark_green = '#0F5223'
light_green = '#69CB87'
light_gray = '#cccccc'
@@ -133,12 +133,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
build_graph(graph, sub_agent, highlight_pairs)
draw_edge(agent.name, sub_agent.name)
if isinstance(agent, LlmAgent):
for tool in agent.canonical_tools:
for tool in await agent.canonical_tools():
draw_node(tool)
draw_edge(agent.name, get_node_name(tool))
def get_agent_graph(root_agent, highlights_pairs, image=False):
async def get_agent_graph(root_agent, highlights_pairs, image=False):
print('build graph')
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
build_graph(graph, root_agent, highlights_pairs)
+21 -9
View File
@@ -24,13 +24,9 @@ import re
import sys
import traceback
import typing
from typing import Any
from typing import List
from typing import Literal
from typing import Optional
from typing import Any, List, Literal, Optional, Union
import click
from click import Tuple
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
@@ -53,9 +49,10 @@ from pydantic import ValidationError
from starlette.types import Lifespan
from ..agents import RunConfig
from ..agents.base_agent import BaseAgent
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.llm_agent import Agent, LlmAgent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
@@ -65,6 +62,7 @@ from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from ..tools.base_toolset import BaseToolset
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
@@ -163,6 +161,7 @@ def get_fast_api_app(
trace.set_tracer_provider(provider)
exit_stacks = []
toolsets_to_close: set[BaseToolset] = set()
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
@@ -173,6 +172,8 @@ def get_fast_api_app(
if exit_stacks:
for stack in exit_stacks:
await stack.aclose()
for toolset in toolsets_to_close:
await toolset.close()
else:
yield
@@ -673,7 +674,7 @@ def get_fast_api_app(
from_name = event.author
to_name = function_call.name
function_call_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph(
dot_graph = await agent_graph.get_agent_graph(
root_agent, function_call_highlights
)
elif function_responses:
@@ -682,13 +683,13 @@ def get_fast_api_app(
from_name = function_response.name
to_name = event.author
function_responses_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph(
dot_graph = await agent_graph.get_agent_graph(
root_agent, function_responses_highlights
)
else:
from_name = event.author
to_name = ""
dot_graph = agent_graph.get_agent_graph(
dot_graph = await agent_graph.get_agent_graph(
root_agent, [(from_name, to_name)]
)
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
@@ -766,6 +767,16 @@ def get_fast_api_app(
for task in pending:
task.cancel()
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
toolsets = set()
if isinstance(agent, LlmAgent):
for tool_union in agent.tools:
if isinstance(tool_union, BaseToolset):
toolsets.add(tool_union)
for sub_agent in agent.sub_agents:
toolsets.update(_get_all_toolsets(sub_agent))
return toolsets
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
@@ -786,6 +797,7 @@ def get_fast_api_app(
raise RuntimeError(f"error getting root agent, {e}") from e
root_agent_dict[app_name] = root_agent
toolsets_to_close.update(_get_all_toolsets(root_agent))
return root_agent
async def _get_runner_async(app_name: str) -> Runner: