feat: extract agent loading logic from fast_api.py to a separate AgentLoader class and support more agent definition folder/file structure.

Structures supported:
a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)

PiperOrigin-RevId: 763943716
This commit is contained in:
Xiang (Sean) Zhou 2025-05-27 14:14:00 -07:00 committed by Copybara-Service
parent 16d9696012
commit 618c824994
3 changed files with 448 additions and 26 deletions

View File

@ -17,11 +17,9 @@ from __future__ import annotations
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import importlib
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
import sys
import time import time
import traceback import traceback
import typing import typing
@ -81,11 +79,11 @@ from .utils import common
from .utils import create_empty_state from .utils import create_empty_state
from .utils import envs from .utils import envs
from .utils import evals from .utils import evals
from .utils.agent_loader import AgentLoader
logger = logging.getLogger("google_adk." + __name__) logger = logging.getLogger("google_adk." + __name__)
_EVAL_SET_FILE_EXTENSION = ".evalset.json" _EVAL_SET_FILE_EXTENSION = ".evalset.json"
_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json"
class ApiServerSpanExporter(export.SpanExporter): class ApiServerSpanExporter(export.SpanExporter):
@ -251,11 +249,7 @@ def get_fast_api_app(
allow_headers=["*"], allow_headers=["*"],
) )
if agents_dir not in sys.path:
sys.path.append(agents_dir)
runner_dict = {} runner_dict = {}
root_agent_dict = {}
# Build the Artifact service # Build the Artifact service
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
@ -282,6 +276,9 @@ def get_fast_api_app(
else: else:
session_service = InMemorySessionService() session_service = InMemorySessionService()
# initialize Agent Loader
agent_loader = AgentLoader(agents_dir)
@app.get("/list-apps") @app.get("/list-apps")
def list_apps() -> list[str]: def list_apps() -> list[str]:
base_path = Path.cwd() / agents_dir base_path = Path.cwd() / agents_dir
@ -450,7 +447,7 @@ def get_fast_api_app(
# Populate the session with initial session state. # Populate the session with initial session state.
initial_session_state = create_empty_state( initial_session_state = create_empty_state(
await _get_root_agent_async(app_name) agent_loader.load_agent(app_name)
) )
new_eval_case = EvalCase( new_eval_case = EvalCase(
@ -492,8 +489,6 @@ def get_fast_api_app(
# Create a mapping from eval set file to all the evals that needed to be # Create a mapping from eval set file to all the evals that needed to be
# run. # run.
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
if req.eval_ids: if req.eval_ids:
@ -503,7 +498,7 @@ def get_fast_api_app(
logger.info("Eval ids to run list is empty. We will run all eval cases.") logger.info("Eval ids to run list is empty. We will run all eval cases.")
eval_set_to_evals = {eval_set_id: eval_set.eval_cases} eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
root_agent = await _get_root_agent_async(app_name) root_agent = agent_loader.load_agent(app_name)
run_eval_results = [] run_eval_results = []
eval_case_results = [] eval_case_results = []
async for eval_case_result in run_evals( async for eval_case_result in run_evals(
@ -741,7 +736,7 @@ def get_fast_api_app(
function_calls = event.get_function_calls() function_calls = event.get_function_calls()
function_responses = event.get_function_responses() function_responses = event.get_function_responses()
root_agent = await _get_root_agent_async(app_name) root_agent = agent_loader.load_agent(app_name)
dot_graph = None dot_graph = None
if function_calls: if function_calls:
function_call_highlights = [] function_call_highlights = []
@ -842,25 +837,12 @@ def get_fast_api_app(
for task in pending: for task in pending:
task.cancel() task.cancel()
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
return root_agent_dict[app_name]
agent_module = importlib.import_module(app_name)
if getattr(agent_module.agent, "root_agent"):
root_agent = agent_module.agent.root_agent
else:
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
root_agent_dict[app_name] = root_agent
return root_agent
async def _get_runner_async(app_name: str) -> Runner: async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app.""" """Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
if app_name in runner_dict: if app_name in runner_dict:
return runner_dict[app_name] return runner_dict[app_name]
root_agent = await _get_root_agent_async(app_name) root_agent = agent_loader.load_agent(app_name)
runner = Runner( runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name, app_name=agent_engine_id if agent_engine_id else app_name,
agent=root_agent, agent=root_agent,

View File

@ -0,0 +1,137 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import logging
import sys
from . import envs
from ...agents.base_agent import BaseAgent
logger = logging.getLogger("google_adk." + __name__)
class AgentLoader:
"""Centralized agent loading with proper isolation, caching, and .env loading.
Support loading agents from below folder/file structures:
a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)
"""
def __init__(self, agents_dir: str):
self.agents_dir = agents_dir.rstrip("/")
self._original_sys_path = None
self._agent_cache: dict[str, BaseAgent] = {}
def _load_from_module_or_package(self, agent_name: str) -> BaseAgent:
# Load for case: Import "<agent_name>" (as a package or module)
# Covers structures:
# a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
# b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
try:
module_candidate = importlib.import_module(agent_name)
# Check for "root_agent" directly in "<agent_name>" module/package
if hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent directly in %s", agent_name)
return module_candidate.root_agent
# Check for "<agent_name>.agent.root_agent" structure (e.g. agent_name is a package,
# and it has an 'agent' submodule/attribute which in turn has 'root_agent')
if hasattr(module_candidate, "agent") and hasattr(
module_candidate.agent, "root_agent"
):
logger.debug("Found root_agent in %s.agent attribute", agent_name)
if isinstance(module_candidate.agent, BaseAgent):
return module_candidate.agent.root_agent
else:
logger.warning(
"Root agent found is not an instance of BaseAgent. But a type %s",
type(module_candidate.agent),
)
except ModuleNotFoundError:
logger.debug("Module %s itself not found.", agent_name)
# Re-raise as ValueError to be caught by the final error message construction
raise ValueError(
f"Module {agent_name} not found during import attempts."
) from None
except ImportError as e:
logger.warning("Error importing %s: %s", agent_name, e)
return None
def _load_from_submodule(self, agent_name: str) -> BaseAgent:
# Load for case: Import "<agent_name>.agent" and look for "root_agent"
# Covers structure: agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)
try:
module_candidate = importlib.import_module(f"{agent_name}.agent")
if hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent in %s.agent", agent_name)
if isinstance(module_candidate.root_agent, BaseAgent):
return module_candidate.root_agent
else:
logger.warning(
"Root agent found is not an instance of BaseAgent. But a type %s",
type(module_candidate.root_agent),
)
except ModuleNotFoundError:
logger.debug(
"Module %s.agent not found, trying next pattern.", agent_name
)
except ImportError as e:
logger.warning("Error importing %s.agent: %s", agent_name, e)
return None
def _perform_load(self, agent_name: str) -> BaseAgent:
"""Internal logic to load an agent"""
# Add self.agents_dir to sys.path
if self.agents_dir not in sys.path:
sys.path.insert(0, self.agents_dir)
logger.debug(
"Loading .env for agent %s from %s", agent_name, self.agents_dir
)
envs.load_dotenv_for_agent(agent_name, str(self.agents_dir))
root_agent = self._load_from_module_or_package(agent_name)
if root_agent:
return root_agent
root_agent = self._load_from_submodule(agent_name)
if root_agent:
return root_agent
# If no root_agent was found by any pattern
raise ValueError(
f"No root_agent found for '{agent_name}'. Searched in"
f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent', and"
f" via an 'agent' attribute within the '{agent_name}' module/package."
f" Ensure '{self.agents_dir}/{agent_name}' is structured correctly,"
" an .env file can be loaded if present, and a root_agent is"
" exposed."
)
def load_agent(self, agent_name: str) -> BaseAgent:
"""Load an agent module (with caching & .env) and return its root_agent (asynchronously)."""
if agent_name in self._agent_cache:
logger.debug("Returning cached agent for %s (async)", agent_name)
return self._agent_cache[agent_name]
logger.debug("Loading agent %s - not in cache.", agent_name)
# Assumes this method is called when the context manager (`with self:`) is active
agent = self._perform_load(agent_name)
self._agent_cache[agent_name] = agent
return agent

View File

@ -0,0 +1,303 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
import sys
import tempfile
from textwrap import dedent
from google.adk.cli.utils.agent_loader import AgentLoader
import pytest
class TestAgentLoader:
"""Unit tests for AgentLoader focusing on interface behavior."""
@pytest.fixture(autouse=True)
def cleanup_sys_path(self):
"""Ensure sys.path is restored after each test."""
original_path = sys.path.copy()
original_env = os.environ.copy()
yield
sys.path[:] = original_path
# Restore environment variables
os.environ.clear()
os.environ.update(original_env)
def create_agent_structure(
self, temp_dir: Path, agent_name: str, structure_type: str
):
"""Create different agent structures for testing.
Args:
temp_dir: The temporary directory to create the agent in
agent_name: Name of the agent
structure_type: One of 'module', 'package_with_root', 'package_with_agent_module'
"""
if structure_type == "module":
# Structure: agents_dir/agent_name.py
agent_file = temp_dir / f"{agent_name}.py"
agent_file.write_text(dedent(f"""
import os
from google.adk.agents.base_agent import BaseAgent
from typing import Any
class {agent_name.title()}Agent(BaseAgent):
agent_id: Any = None
config: Any = None
def __init__(self):
super().__init__(name="{agent_name}")
self.agent_id = id(self)
self.config = os.environ.get("AGENT_CONFIG", "default")
root_agent = {agent_name.title()}Agent()
"""))
elif structure_type == "package_with_root":
# Structure: agents_dir/agent_name/__init__.py (with root_agent)
agent_dir = temp_dir / agent_name
agent_dir.mkdir()
init_file = agent_dir / "__init__.py"
init_file.write_text(dedent(f"""
import os
from google.adk.agents.base_agent import BaseAgent
from typing import Any
class {agent_name.title()}Agent(BaseAgent):
agent_id: Any = None
config: Any = None
def __init__(self):
super().__init__(name="{agent_name}")
self.agent_id = id(self)
self.config = os.environ.get("AGENT_CONFIG", "default")
root_agent = {agent_name.title()}Agent()
"""))
elif structure_type == "package_with_agent_module":
# Structure: agents_dir/agent_name/agent.py
agent_dir = temp_dir / agent_name
agent_dir.mkdir()
# Create __init__.py
init_file = agent_dir / "__init__.py"
init_file.write_text("")
# Create agent.py with root_agent
agent_file = agent_dir / "agent.py"
agent_file.write_text(dedent(f"""
import os
from google.adk.agents.base_agent import BaseAgent
from typing import Any
class {agent_name.title()}Agent(BaseAgent):
agent_id: Any = None
config: Any = None
def __init__(self):
super().__init__(name="{agent_name}")
self.agent_id = id(self)
self.config = os.environ.get("AGENT_CONFIG", "default")
root_agent = {agent_name.title()}Agent()
"""))
def create_env_file(self, temp_dir: Path, agent_name: str, env_vars: dict):
"""Create a .env file for the agent."""
env_file = temp_dir / agent_name / ".env"
env_file.parent.mkdir(exist_ok=True)
env_content = "\n".join(
[f"{key}={value}" for key, value in env_vars.items()]
)
env_file.write_text(env_content)
def test_load_agent_as_module(self):
"""Test loading an agent structured as a single module file."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent as module
self.create_agent_structure(temp_path, "module_agent", "module")
# Load the agent
loader = AgentLoader(str(temp_path))
agent = loader.load_agent("module_agent")
# Assert agent was loaded correctly
assert agent.name == "module_agent"
assert hasattr(agent, "agent_id")
assert agent.config == "default"
def test_load_agent_as_package_with_root_agent(self):
"""Test loading an agent structured as a package with root_agent in __init__.py."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent as package
self.create_agent_structure(
temp_path, "package_agent", "package_with_root"
)
# Load the agent
loader = AgentLoader(str(temp_path))
agent = loader.load_agent("package_agent")
# Assert agent was loaded correctly
assert agent.name == "package_agent"
assert hasattr(agent, "agent_id")
def test_load_agent_as_package_with_agent_module(self):
"""Test loading an agent structured as a package with separate agent.py module."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent as package with agent.py
self.create_agent_structure(
temp_path, "modular_agent", "package_with_agent_module"
)
# Load the agent
loader = AgentLoader(str(temp_path))
agent = loader.load_agent("modular_agent")
# Assert agent was loaded correctly
assert agent.name == "modular_agent"
assert hasattr(agent, "agent_id")
def test_agent_caching_returns_same_instance(self):
"""Test that loading the same agent twice returns the same instance."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent
self.create_agent_structure(temp_path, "cached_agent", "module")
# Load the agent twice
loader = AgentLoader(str(temp_path))
agent1 = loader.load_agent("cached_agent")
agent2 = loader.load_agent("cached_agent")
# Assert same instance is returned
assert agent1 is agent2
assert agent1.agent_id == agent2.agent_id
def test_env_loading_for_agent(self):
"""Test that .env file is loaded for the agent."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent and .env file
self.create_agent_structure(temp_path, "env_agent", "package_with_root")
self.create_env_file(
temp_path,
"env_agent",
{"AGENT_CONFIG": "production", "AGENT_SECRET": "test_secret_123"},
)
# Load the agent
loader = AgentLoader(str(temp_path))
agent = loader.load_agent("env_agent")
# Assert environment variables were loaded
assert agent.config == "production"
assert os.environ.get("AGENT_SECRET") == "test_secret_123"
def test_load_multiple_different_agents(self):
"""Test loading multiple different agents."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create multiple agents with different structures
self.create_agent_structure(temp_path, "agent_one", "module")
self.create_agent_structure(temp_path, "agent_two", "package_with_root")
self.create_agent_structure(
temp_path, "agent_three", "package_with_agent_module"
)
# Load all agents
loader = AgentLoader(str(temp_path))
agent1 = loader.load_agent("agent_one")
agent2 = loader.load_agent("agent_two")
agent3 = loader.load_agent("agent_three")
# Assert all agents were loaded correctly and are different instances
assert agent1.name == "agent_one"
assert agent2.name == "agent_two"
assert agent3.name == "agent_three"
assert agent1 is not agent2
assert agent2 is not agent3
assert agent1.agent_id != agent2.agent_id != agent3.agent_id
def test_agent_not_found_error(self):
"""Test that appropriate error is raised when agent is not found."""
with tempfile.TemporaryDirectory() as temp_dir:
loader = AgentLoader(temp_dir)
# Try to load non-existent agent
with pytest.raises(ValueError) as exc_info:
loader.load_agent("nonexistent_agent")
assert "Module nonexistent_agent not found" in str(exc_info.value)
def test_agent_without_root_agent_error(self):
"""Test that appropriate error is raised when agent has no root_agent."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent without root_agent
agent_file = temp_path / "broken_agent.py"
agent_file.write_text(dedent("""
class BrokenAgent:
def __init__(self):
self.name = "broken"
# Note: No root_agent defined
"""))
loader = AgentLoader(str(temp_path))
# Try to load agent without root_agent
with pytest.raises(ValueError) as exc_info:
loader.load_agent("broken_agent")
assert "No root_agent found for 'broken_agent'" in str(exc_info.value)
def test_sys_path_modification(self):
"""Test that agents_dir is added to sys.path correctly."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create agent
self.create_agent_structure(temp_path, "path_agent", "module")
# Check sys.path before
assert str(temp_path) not in sys.path
loader = AgentLoader(str(temp_path))
# Path should not be added yet - only added during load
assert str(temp_path) not in sys.path
# Load agent - this should add the path
agent = loader.load_agent("path_agent")
# Now assert path was added
assert str(temp_path) in sys.path
assert agent.name == "path_agent"