mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-04 13:56:24 -06:00
add toolset base class and allow llm agent to accept toolset as tools
PiperOrigin-RevId: 756605470
This commit is contained in:
committed by
Copybara-Service
parent
8963300518
commit
4d7298e4f2
@@ -35,7 +35,7 @@ else:
|
||||
retrieval_tool_module_loaded = True
|
||||
|
||||
|
||||
def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
async def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
dark_green = '#0F5223'
|
||||
light_green = '#69CB87'
|
||||
light_gray = '#cccccc'
|
||||
@@ -133,12 +133,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
build_graph(graph, sub_agent, highlight_pairs)
|
||||
draw_edge(agent.name, sub_agent.name)
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool in agent.canonical_tools:
|
||||
for tool in await agent.canonical_tools():
|
||||
draw_node(tool)
|
||||
draw_edge(agent.name, get_node_name(tool))
|
||||
|
||||
|
||||
def get_agent_graph(root_agent, highlights_pairs, image=False):
|
||||
async def get_agent_graph(root_agent, highlights_pairs, image=False):
|
||||
print('build graph')
|
||||
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
|
||||
build_graph(graph, root_agent, highlights_pairs)
|
||||
|
||||
@@ -24,13 +24,9 @@ import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
import click
|
||||
from click import Tuple
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
@@ -53,9 +49,10 @@ from pydantic import ValidationError
|
||||
from starlette.types import Lifespan
|
||||
|
||||
from ..agents import RunConfig
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.llm_agent import Agent, LlmAgent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..events.event import Event
|
||||
@@ -65,6 +62,7 @@ from ..sessions.database_session_service import DatabaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
||||
from ..tools.base_toolset import BaseToolset
|
||||
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalMetricResult
|
||||
@@ -163,6 +161,7 @@ def get_fast_api_app(
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
exit_stacks = []
|
||||
toolsets_to_close: set[BaseToolset] = set()
|
||||
|
||||
@asynccontextmanager
|
||||
async def internal_lifespan(app: FastAPI):
|
||||
@@ -173,6 +172,8 @@ def get_fast_api_app(
|
||||
if exit_stacks:
|
||||
for stack in exit_stacks:
|
||||
await stack.aclose()
|
||||
for toolset in toolsets_to_close:
|
||||
await toolset.close()
|
||||
else:
|
||||
yield
|
||||
|
||||
@@ -673,7 +674,7 @@ def get_fast_api_app(
|
||||
from_name = event.author
|
||||
to_name = function_call.name
|
||||
function_call_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
dot_graph = await agent_graph.get_agent_graph(
|
||||
root_agent, function_call_highlights
|
||||
)
|
||||
elif function_responses:
|
||||
@@ -682,13 +683,13 @@ def get_fast_api_app(
|
||||
from_name = function_response.name
|
||||
to_name = event.author
|
||||
function_responses_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
dot_graph = await agent_graph.get_agent_graph(
|
||||
root_agent, function_responses_highlights
|
||||
)
|
||||
else:
|
||||
from_name = event.author
|
||||
to_name = ""
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
dot_graph = await agent_graph.get_agent_graph(
|
||||
root_agent, [(from_name, to_name)]
|
||||
)
|
||||
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
|
||||
@@ -766,6 +767,16 @@ def get_fast_api_app(
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
|
||||
toolsets = set()
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool_union in agent.tools:
|
||||
if isinstance(tool_union, BaseToolset):
|
||||
toolsets.add(tool_union)
|
||||
for sub_agent in agent.sub_agents:
|
||||
toolsets.update(_get_all_toolsets(sub_agent))
|
||||
return toolsets
|
||||
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
@@ -786,6 +797,7 @@ def get_fast_api_app(
|
||||
raise RuntimeError(f"error getting root agent, {e}") from e
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
toolsets_to_close.update(_get_all_toolsets(root_agent))
|
||||
return root_agent
|
||||
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
|
||||
Reference in New Issue
Block a user