From 618c8249947180c4e6603cad170368da8f753907 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 27 May 2025 14:14:00 -0700 Subject: [PATCH] feat: extract agent loading logic from fast_api.py to a separate AgentLoader class and support more agent definition folder/file structure. Structures supported: a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it) b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package) c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent) PiperOrigin-RevId: 763943716 --- src/google/adk/cli/fast_api.py | 34 +- src/google/adk/cli/utils/agent_loader.py | 137 ++++++++ .../unittests/cli/utils/test_agent_loader.py | 303 ++++++++++++++++++ 3 files changed, 448 insertions(+), 26 deletions(-) create mode 100644 src/google/adk/cli/utils/agent_loader.py create mode 100644 tests/unittests/cli/utils/test_agent_loader.py diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 5090f05..b4909bc 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -17,11 +17,9 @@ from __future__ import annotations import asyncio from contextlib import asynccontextmanager -import importlib import logging import os from pathlib import Path -import sys import time import traceback import typing @@ -81,11 +79,11 @@ from .utils import common from .utils import create_empty_state from .utils import envs from .utils import evals +from .utils.agent_loader import AgentLoader logger = logging.getLogger("google_adk." + __name__) _EVAL_SET_FILE_EXTENSION = ".evalset.json" -_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json" class ApiServerSpanExporter(export.SpanExporter): @@ -251,11 +249,7 @@ def get_fast_api_app( allow_headers=["*"], ) - if agents_dir not in sys.path: - sys.path.append(agents_dir) - runner_dict = {} - root_agent_dict = {} # Build the Artifact service artifact_service = InMemoryArtifactService() @@ -282,6 +276,9 @@ def get_fast_api_app( else: session_service = InMemorySessionService() + # initialize Agent Loader + agent_loader = AgentLoader(agents_dir) + @app.get("/list-apps") def list_apps() -> list[str]: base_path = Path.cwd() / agents_dir @@ -450,7 +447,7 @@ def get_fast_api_app( # Populate the session with initial session state. initial_session_state = create_empty_state( - await _get_root_agent_async(app_name) + agent_loader.load_agent(app_name) ) new_eval_case = EvalCase( @@ -492,8 +489,6 @@ def get_fast_api_app( # Create a mapping from eval set file to all the evals that needed to be # run. - envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) - eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) if req.eval_ids: @@ -503,7 +498,7 @@ def get_fast_api_app( logger.info("Eval ids to run list is empty. We will run all eval cases.") eval_set_to_evals = {eval_set_id: eval_set.eval_cases} - root_agent = await _get_root_agent_async(app_name) + root_agent = agent_loader.load_agent(app_name) run_eval_results = [] eval_case_results = [] async for eval_case_result in run_evals( @@ -741,7 +736,7 @@ def get_fast_api_app( function_calls = event.get_function_calls() function_responses = event.get_function_responses() - root_agent = await _get_root_agent_async(app_name) + root_agent = agent_loader.load_agent(app_name) dot_graph = None if function_calls: function_call_highlights = [] @@ -842,25 +837,12 @@ def get_fast_api_app( for task in pending: task.cancel() - async def _get_root_agent_async(app_name: str) -> Agent: - """Returns the root agent for the given app.""" - if app_name in root_agent_dict: - return root_agent_dict[app_name] - agent_module = importlib.import_module(app_name) - if getattr(agent_module.agent, "root_agent"): - root_agent = agent_module.agent.root_agent - else: - raise ValueError(f'Unable to find "root_agent" from {app_name}.') - - root_agent_dict[app_name] = root_agent - return root_agent - async def _get_runner_async(app_name: str) -> Runner: """Returns the runner for the given app.""" envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) if app_name in runner_dict: return runner_dict[app_name] - root_agent = await _get_root_agent_async(app_name) + root_agent = agent_loader.load_agent(app_name) runner = Runner( app_name=agent_engine_id if agent_engine_id else app_name, agent=root_agent, diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py new file mode 100644 index 0000000..69753ae --- /dev/null +++ b/src/google/adk/cli/utils/agent_loader.py @@ -0,0 +1,137 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import logging +import sys + +from . import envs +from ...agents.base_agent import BaseAgent + +logger = logging.getLogger("google_adk." + __name__) + + +class AgentLoader: + """Centralized agent loading with proper isolation, caching, and .env loading. + Support loading agents from below folder/file structures: + a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it) + b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package) + c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent) + """ + + def __init__(self, agents_dir: str): + self.agents_dir = agents_dir.rstrip("/") + self._original_sys_path = None + self._agent_cache: dict[str, BaseAgent] = {} + + def _load_from_module_or_package(self, agent_name: str) -> BaseAgent: + # Load for case: Import "" (as a package or module) + # Covers structures: + # a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it) + # b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package) + try: + module_candidate = importlib.import_module(agent_name) + # Check for "root_agent" directly in "" module/package + if hasattr(module_candidate, "root_agent"): + logger.debug("Found root_agent directly in %s", agent_name) + return module_candidate.root_agent + # Check for ".agent.root_agent" structure (e.g. agent_name is a package, + # and it has an 'agent' submodule/attribute which in turn has 'root_agent') + if hasattr(module_candidate, "agent") and hasattr( + module_candidate.agent, "root_agent" + ): + logger.debug("Found root_agent in %s.agent attribute", agent_name) + if isinstance(module_candidate.agent, BaseAgent): + return module_candidate.agent.root_agent + else: + logger.warning( + "Root agent found is not an instance of BaseAgent. But a type %s", + type(module_candidate.agent), + ) + except ModuleNotFoundError: + logger.debug("Module %s itself not found.", agent_name) + # Re-raise as ValueError to be caught by the final error message construction + raise ValueError( + f"Module {agent_name} not found during import attempts." + ) from None + except ImportError as e: + logger.warning("Error importing %s: %s", agent_name, e) + + return None + + def _load_from_submodule(self, agent_name: str) -> BaseAgent: + # Load for case: Import ".agent" and look for "root_agent" + # Covers structure: agents_dir/agent_name_folder/agent.py (where agent.py has root_agent) + try: + module_candidate = importlib.import_module(f"{agent_name}.agent") + if hasattr(module_candidate, "root_agent"): + logger.debug("Found root_agent in %s.agent", agent_name) + if isinstance(module_candidate.root_agent, BaseAgent): + return module_candidate.root_agent + else: + logger.warning( + "Root agent found is not an instance of BaseAgent. But a type %s", + type(module_candidate.root_agent), + ) + except ModuleNotFoundError: + logger.debug( + "Module %s.agent not found, trying next pattern.", agent_name + ) + except ImportError as e: + logger.warning("Error importing %s.agent: %s", agent_name, e) + + return None + + def _perform_load(self, agent_name: str) -> BaseAgent: + """Internal logic to load an agent""" + # Add self.agents_dir to sys.path + if self.agents_dir not in sys.path: + sys.path.insert(0, self.agents_dir) + + logger.debug( + "Loading .env for agent %s from %s", agent_name, self.agents_dir + ) + envs.load_dotenv_for_agent(agent_name, str(self.agents_dir)) + + root_agent = self._load_from_module_or_package(agent_name) + if root_agent: + return root_agent + + root_agent = self._load_from_submodule(agent_name) + if root_agent: + return root_agent + + # If no root_agent was found by any pattern + raise ValueError( + f"No root_agent found for '{agent_name}'. Searched in" + f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent', and" + f" via an 'agent' attribute within the '{agent_name}' module/package." + f" Ensure '{self.agents_dir}/{agent_name}' is structured correctly," + " an .env file can be loaded if present, and a root_agent is" + " exposed." + ) + + def load_agent(self, agent_name: str) -> BaseAgent: + """Load an agent module (with caching & .env) and return its root_agent (asynchronously).""" + if agent_name in self._agent_cache: + logger.debug("Returning cached agent for %s (async)", agent_name) + return self._agent_cache[agent_name] + + logger.debug("Loading agent %s - not in cache.", agent_name) + # Assumes this method is called when the context manager (`with self:`) is active + agent = self._perform_load(agent_name) + self._agent_cache[agent_name] = agent + return agent diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_agent_loader.py new file mode 100644 index 0000000..32d7bdc --- /dev/null +++ b/tests/unittests/cli/utils/test_agent_loader.py @@ -0,0 +1,303 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +import sys +import tempfile +from textwrap import dedent + +from google.adk.cli.utils.agent_loader import AgentLoader +import pytest + + +class TestAgentLoader: + """Unit tests for AgentLoader focusing on interface behavior.""" + + @pytest.fixture(autouse=True) + def cleanup_sys_path(self): + """Ensure sys.path is restored after each test.""" + original_path = sys.path.copy() + original_env = os.environ.copy() + yield + sys.path[:] = original_path + # Restore environment variables + os.environ.clear() + os.environ.update(original_env) + + def create_agent_structure( + self, temp_dir: Path, agent_name: str, structure_type: str + ): + """Create different agent structures for testing. + + Args: + temp_dir: The temporary directory to create the agent in + agent_name: Name of the agent + structure_type: One of 'module', 'package_with_root', 'package_with_agent_module' + """ + if structure_type == "module": + # Structure: agents_dir/agent_name.py + agent_file = temp_dir / f"{agent_name}.py" + agent_file.write_text(dedent(f""" + import os + from google.adk.agents.base_agent import BaseAgent + from typing import Any + + class {agent_name.title()}Agent(BaseAgent): + agent_id: Any = None + config: Any = None + + def __init__(self): + super().__init__(name="{agent_name}") + self.agent_id = id(self) + self.config = os.environ.get("AGENT_CONFIG", "default") + + root_agent = {agent_name.title()}Agent() + + + """)) + + elif structure_type == "package_with_root": + # Structure: agents_dir/agent_name/__init__.py (with root_agent) + agent_dir = temp_dir / agent_name + agent_dir.mkdir() + init_file = agent_dir / "__init__.py" + init_file.write_text(dedent(f""" + import os + from google.adk.agents.base_agent import BaseAgent + from typing import Any + + class {agent_name.title()}Agent(BaseAgent): + agent_id: Any = None + config: Any = None + + def __init__(self): + super().__init__(name="{agent_name}") + self.agent_id = id(self) + self.config = os.environ.get("AGENT_CONFIG", "default") + + root_agent = {agent_name.title()}Agent() + """)) + + elif structure_type == "package_with_agent_module": + # Structure: agents_dir/agent_name/agent.py + agent_dir = temp_dir / agent_name + agent_dir.mkdir() + + # Create __init__.py + init_file = agent_dir / "__init__.py" + init_file.write_text("") + + # Create agent.py with root_agent + agent_file = agent_dir / "agent.py" + agent_file.write_text(dedent(f""" + import os + from google.adk.agents.base_agent import BaseAgent + from typing import Any + + class {agent_name.title()}Agent(BaseAgent): + agent_id: Any = None + config: Any = None + + def __init__(self): + super().__init__(name="{agent_name}") + self.agent_id = id(self) + self.config = os.environ.get("AGENT_CONFIG", "default") + + root_agent = {agent_name.title()}Agent() + """)) + + def create_env_file(self, temp_dir: Path, agent_name: str, env_vars: dict): + """Create a .env file for the agent.""" + env_file = temp_dir / agent_name / ".env" + env_file.parent.mkdir(exist_ok=True) + + env_content = "\n".join( + [f"{key}={value}" for key, value in env_vars.items()] + ) + env_file.write_text(env_content) + + def test_load_agent_as_module(self): + """Test loading an agent structured as a single module file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent as module + self.create_agent_structure(temp_path, "module_agent", "module") + + # Load the agent + loader = AgentLoader(str(temp_path)) + agent = loader.load_agent("module_agent") + + # Assert agent was loaded correctly + assert agent.name == "module_agent" + assert hasattr(agent, "agent_id") + assert agent.config == "default" + + def test_load_agent_as_package_with_root_agent(self): + """Test loading an agent structured as a package with root_agent in __init__.py.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent as package + self.create_agent_structure( + temp_path, "package_agent", "package_with_root" + ) + + # Load the agent + loader = AgentLoader(str(temp_path)) + agent = loader.load_agent("package_agent") + + # Assert agent was loaded correctly + assert agent.name == "package_agent" + assert hasattr(agent, "agent_id") + + def test_load_agent_as_package_with_agent_module(self): + """Test loading an agent structured as a package with separate agent.py module.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent as package with agent.py + self.create_agent_structure( + temp_path, "modular_agent", "package_with_agent_module" + ) + + # Load the agent + loader = AgentLoader(str(temp_path)) + agent = loader.load_agent("modular_agent") + + # Assert agent was loaded correctly + assert agent.name == "modular_agent" + assert hasattr(agent, "agent_id") + + def test_agent_caching_returns_same_instance(self): + """Test that loading the same agent twice returns the same instance.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent + self.create_agent_structure(temp_path, "cached_agent", "module") + + # Load the agent twice + loader = AgentLoader(str(temp_path)) + agent1 = loader.load_agent("cached_agent") + agent2 = loader.load_agent("cached_agent") + + # Assert same instance is returned + assert agent1 is agent2 + assert agent1.agent_id == agent2.agent_id + + def test_env_loading_for_agent(self): + """Test that .env file is loaded for the agent.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent and .env file + self.create_agent_structure(temp_path, "env_agent", "package_with_root") + self.create_env_file( + temp_path, + "env_agent", + {"AGENT_CONFIG": "production", "AGENT_SECRET": "test_secret_123"}, + ) + + # Load the agent + loader = AgentLoader(str(temp_path)) + agent = loader.load_agent("env_agent") + + # Assert environment variables were loaded + assert agent.config == "production" + assert os.environ.get("AGENT_SECRET") == "test_secret_123" + + def test_load_multiple_different_agents(self): + """Test loading multiple different agents.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create multiple agents with different structures + self.create_agent_structure(temp_path, "agent_one", "module") + self.create_agent_structure(temp_path, "agent_two", "package_with_root") + self.create_agent_structure( + temp_path, "agent_three", "package_with_agent_module" + ) + + # Load all agents + loader = AgentLoader(str(temp_path)) + agent1 = loader.load_agent("agent_one") + agent2 = loader.load_agent("agent_two") + agent3 = loader.load_agent("agent_three") + + # Assert all agents were loaded correctly and are different instances + assert agent1.name == "agent_one" + assert agent2.name == "agent_two" + assert agent3.name == "agent_three" + assert agent1 is not agent2 + assert agent2 is not agent3 + assert agent1.agent_id != agent2.agent_id != agent3.agent_id + + def test_agent_not_found_error(self): + """Test that appropriate error is raised when agent is not found.""" + with tempfile.TemporaryDirectory() as temp_dir: + loader = AgentLoader(temp_dir) + + # Try to load non-existent agent + with pytest.raises(ValueError) as exc_info: + loader.load_agent("nonexistent_agent") + + assert "Module nonexistent_agent not found" in str(exc_info.value) + + def test_agent_without_root_agent_error(self): + """Test that appropriate error is raised when agent has no root_agent.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent without root_agent + agent_file = temp_path / "broken_agent.py" + agent_file.write_text(dedent(""" + class BrokenAgent: + def __init__(self): + self.name = "broken" + + # Note: No root_agent defined + """)) + + loader = AgentLoader(str(temp_path)) + + # Try to load agent without root_agent + with pytest.raises(ValueError) as exc_info: + loader.load_agent("broken_agent") + + assert "No root_agent found for 'broken_agent'" in str(exc_info.value) + + def test_sys_path_modification(self): + """Test that agents_dir is added to sys.path correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create agent + self.create_agent_structure(temp_path, "path_agent", "module") + + # Check sys.path before + assert str(temp_path) not in sys.path + + loader = AgentLoader(str(temp_path)) + + # Path should not be added yet - only added during load + assert str(temp_path) not in sys.path + + # Load agent - this should add the path + agent = loader.load_agent("path_agent") + + # Now assert path was added + assert str(temp_path) in sys.path + assert agent.name == "path_agent"