mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00
feat: extract agent loading logic from fast_api.py to a separate AgentLoader class and support more agent definition folder/file structure.
Structures supported: a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it) b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package) c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent) PiperOrigin-RevId: 763943716
This commit is contained in:
committed by
Copybara-Service
parent
16d9696012
commit
618c824994
@@ -17,11 +17,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing
|
||||
@@ -81,11 +79,11 @@ from .utils import common
|
||||
from .utils import create_empty_state
|
||||
from .utils import envs
|
||||
from .utils import evals
|
||||
from .utils.agent_loader import AgentLoader
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
||||
_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json"
|
||||
|
||||
|
||||
class ApiServerSpanExporter(export.SpanExporter):
|
||||
@@ -251,11 +249,7 @@ def get_fast_api_app(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if agents_dir not in sys.path:
|
||||
sys.path.append(agents_dir)
|
||||
|
||||
runner_dict = {}
|
||||
root_agent_dict = {}
|
||||
|
||||
# Build the Artifact service
|
||||
artifact_service = InMemoryArtifactService()
|
||||
@@ -282,6 +276,9 @@ def get_fast_api_app(
|
||||
else:
|
||||
session_service = InMemorySessionService()
|
||||
|
||||
# initialize Agent Loader
|
||||
agent_loader = AgentLoader(agents_dir)
|
||||
|
||||
@app.get("/list-apps")
|
||||
def list_apps() -> list[str]:
|
||||
base_path = Path.cwd() / agents_dir
|
||||
@@ -450,7 +447,7 @@ def get_fast_api_app(
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(
|
||||
await _get_root_agent_async(app_name)
|
||||
agent_loader.load_agent(app_name)
|
||||
)
|
||||
|
||||
new_eval_case = EvalCase(
|
||||
@@ -492,8 +489,6 @@ def get_fast_api_app(
|
||||
|
||||
# Create a mapping from eval set file to all the evals that needed to be
|
||||
# run.
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
||||
|
||||
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
||||
|
||||
if req.eval_ids:
|
||||
@@ -503,7 +498,7 @@ def get_fast_api_app(
|
||||
logger.info("Eval ids to run list is empty. We will run all eval cases.")
|
||||
eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
|
||||
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
root_agent = agent_loader.load_agent(app_name)
|
||||
run_eval_results = []
|
||||
eval_case_results = []
|
||||
async for eval_case_result in run_evals(
|
||||
@@ -741,7 +736,7 @@ def get_fast_api_app(
|
||||
|
||||
function_calls = event.get_function_calls()
|
||||
function_responses = event.get_function_responses()
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
root_agent = agent_loader.load_agent(app_name)
|
||||
dot_graph = None
|
||||
if function_calls:
|
||||
function_call_highlights = []
|
||||
@@ -842,25 +837,12 @@ def get_fast_api_app(
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
return root_agent_dict[app_name]
|
||||
agent_module = importlib.import_module(app_name)
|
||||
if getattr(agent_module.agent, "root_agent"):
|
||||
root_agent = agent_module.agent.root_agent
|
||||
else:
|
||||
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
return root_agent
|
||||
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
"""Returns the runner for the given app."""
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
||||
if app_name in runner_dict:
|
||||
return runner_dict[app_name]
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
root_agent = agent_loader.load_agent(app_name)
|
||||
runner = Runner(
|
||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||
agent=root_agent,
|
||||
|
||||
137
src/google/adk/cli/utils/agent_loader.py
Normal file
137
src/google/adk/cli/utils/agent_loader.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from . import envs
|
||||
from ...agents.base_agent import BaseAgent
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class AgentLoader:
|
||||
"""Centralized agent loading with proper isolation, caching, and .env loading.
|
||||
Support loading agents from below folder/file structures:
|
||||
a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
|
||||
b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
|
||||
c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)
|
||||
"""
|
||||
|
||||
def __init__(self, agents_dir: str):
|
||||
self.agents_dir = agents_dir.rstrip("/")
|
||||
self._original_sys_path = None
|
||||
self._agent_cache: dict[str, BaseAgent] = {}
|
||||
|
||||
def _load_from_module_or_package(self, agent_name: str) -> BaseAgent:
|
||||
# Load for case: Import "<agent_name>" (as a package or module)
|
||||
# Covers structures:
|
||||
# a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
|
||||
# b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
|
||||
try:
|
||||
module_candidate = importlib.import_module(agent_name)
|
||||
# Check for "root_agent" directly in "<agent_name>" module/package
|
||||
if hasattr(module_candidate, "root_agent"):
|
||||
logger.debug("Found root_agent directly in %s", agent_name)
|
||||
return module_candidate.root_agent
|
||||
# Check for "<agent_name>.agent.root_agent" structure (e.g. agent_name is a package,
|
||||
# and it has an 'agent' submodule/attribute which in turn has 'root_agent')
|
||||
if hasattr(module_candidate, "agent") and hasattr(
|
||||
module_candidate.agent, "root_agent"
|
||||
):
|
||||
logger.debug("Found root_agent in %s.agent attribute", agent_name)
|
||||
if isinstance(module_candidate.agent, BaseAgent):
|
||||
return module_candidate.agent.root_agent
|
||||
else:
|
||||
logger.warning(
|
||||
"Root agent found is not an instance of BaseAgent. But a type %s",
|
||||
type(module_candidate.agent),
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
logger.debug("Module %s itself not found.", agent_name)
|
||||
# Re-raise as ValueError to be caught by the final error message construction
|
||||
raise ValueError(
|
||||
f"Module {agent_name} not found during import attempts."
|
||||
) from None
|
||||
except ImportError as e:
|
||||
logger.warning("Error importing %s: %s", agent_name, e)
|
||||
|
||||
return None
|
||||
|
||||
def _load_from_submodule(self, agent_name: str) -> BaseAgent:
|
||||
# Load for case: Import "<agent_name>.agent" and look for "root_agent"
|
||||
# Covers structure: agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)
|
||||
try:
|
||||
module_candidate = importlib.import_module(f"{agent_name}.agent")
|
||||
if hasattr(module_candidate, "root_agent"):
|
||||
logger.debug("Found root_agent in %s.agent", agent_name)
|
||||
if isinstance(module_candidate.root_agent, BaseAgent):
|
||||
return module_candidate.root_agent
|
||||
else:
|
||||
logger.warning(
|
||||
"Root agent found is not an instance of BaseAgent. But a type %s",
|
||||
type(module_candidate.root_agent),
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
logger.debug(
|
||||
"Module %s.agent not found, trying next pattern.", agent_name
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Error importing %s.agent: %s", agent_name, e)
|
||||
|
||||
return None
|
||||
|
||||
def _perform_load(self, agent_name: str) -> BaseAgent:
|
||||
"""Internal logic to load an agent"""
|
||||
# Add self.agents_dir to sys.path
|
||||
if self.agents_dir not in sys.path:
|
||||
sys.path.insert(0, self.agents_dir)
|
||||
|
||||
logger.debug(
|
||||
"Loading .env for agent %s from %s", agent_name, self.agents_dir
|
||||
)
|
||||
envs.load_dotenv_for_agent(agent_name, str(self.agents_dir))
|
||||
|
||||
root_agent = self._load_from_module_or_package(agent_name)
|
||||
if root_agent:
|
||||
return root_agent
|
||||
|
||||
root_agent = self._load_from_submodule(agent_name)
|
||||
if root_agent:
|
||||
return root_agent
|
||||
|
||||
# If no root_agent was found by any pattern
|
||||
raise ValueError(
|
||||
f"No root_agent found for '{agent_name}'. Searched in"
|
||||
f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent', and"
|
||||
f" via an 'agent' attribute within the '{agent_name}' module/package."
|
||||
f" Ensure '{self.agents_dir}/{agent_name}' is structured correctly,"
|
||||
" an .env file can be loaded if present, and a root_agent is"
|
||||
" exposed."
|
||||
)
|
||||
|
||||
def load_agent(self, agent_name: str) -> BaseAgent:
|
||||
"""Load an agent module (with caching & .env) and return its root_agent (asynchronously)."""
|
||||
if agent_name in self._agent_cache:
|
||||
logger.debug("Returning cached agent for %s (async)", agent_name)
|
||||
return self._agent_cache[agent_name]
|
||||
|
||||
logger.debug("Loading agent %s - not in cache.", agent_name)
|
||||
# Assumes this method is called when the context manager (`with self:`) is active
|
||||
agent = self._perform_load(agent_name)
|
||||
self._agent_cache[agent_name] = agent
|
||||
return agent
|
||||
Reference in New Issue
Block a user