refactor: refine agent_loader to load the most common folder structure first and don't including unnecessary case

PiperOrigin-RevId: 764545194
This commit is contained in:
Xiang (Sean) Zhou 2025-05-28 21:05:31 -07:00 committed by Copybara-Service
parent 6157db77f2
commit 623957c0a8

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import importlib import importlib
import logging import logging
import sys import sys
from typing import Optional
from . import envs from . import envs
from ...agents.base_agent import BaseAgent from ...agents.base_agent import BaseAgent
@ -27,9 +28,13 @@ logger = logging.getLogger("google_adk." + __name__)
class AgentLoader: class AgentLoader:
"""Centralized agent loading with proper isolation, caching, and .env loading. """Centralized agent loading with proper isolation, caching, and .env loading.
Support loading agents from below folder/file structures: Support loading agents from below folder/file structures:
a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it) a) {agent_name}.agent as a module name:
b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package) agents_dir/{agent_name}/agent.py (with root_agent defined in the module)
c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent) b) {agent_name} as a module name
agents_dir/{agent_name}.py (with root_agent defined in the module)
c) {agent_name} as a package name
agents_dir/{agent_name}/__init__.py (with root_agent in the package)
""" """
def __init__(self, agents_dir: str): def __init__(self, agents_dir: str):
@ -37,30 +42,26 @@ class AgentLoader:
self._original_sys_path = None self._original_sys_path = None
self._agent_cache: dict[str, BaseAgent] = {} self._agent_cache: dict[str, BaseAgent] = {}
def _load_from_module_or_package(self, agent_name: str) -> BaseAgent: def _load_from_module_or_package(
# Load for case: Import "<agent_name>" (as a package or module) self, agent_name: str
) -> Optional[BaseAgent]:
# Load for case: Import "{agent_name}" (as a package or module)
# Covers structures: # Covers structures:
# a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it) # a) agents_dir/{agent_name}.py (with root_agent in the module)
# b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package) # b) agents_dir/{agent_name}/__init__.py (with root_agent in the package)
try: try:
module_candidate = importlib.import_module(agent_name) module_candidate = importlib.import_module(agent_name)
# Check for "root_agent" directly in "<agent_name>" module/package # Check for "root_agent" directly in "{agent_name}" module/package
if hasattr(module_candidate, "root_agent"): if hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent directly in %s", agent_name) logger.debug("Found root_agent directly in %s", agent_name)
return module_candidate.root_agent if isinstance(module_candidate.root_agent, BaseAgent):
# Check for "<agent_name>.agent.root_agent" structure (e.g. agent_name is a package, return module_candidate.root_agent
# and it has an 'agent' submodule/attribute which in turn has 'root_agent')
if hasattr(module_candidate, "agent") and hasattr(
module_candidate.agent, "root_agent"
):
logger.debug("Found root_agent in %s.agent attribute", agent_name)
if isinstance(module_candidate.agent, BaseAgent):
return module_candidate.agent.root_agent
else: else:
logger.warning( logger.warning(
"Root agent found is not an instance of BaseAgent. But a type %s", "Root agent found is not an instance of BaseAgent. But a type %s",
type(module_candidate.agent), type(module_candidate.root_agent),
) )
except ModuleNotFoundError: except ModuleNotFoundError:
logger.debug("Module %s itself not found.", agent_name) logger.debug("Module %s itself not found.", agent_name)
# Re-raise as ValueError to be caught by the final error message construction # Re-raise as ValueError to be caught by the final error message construction
@ -72,13 +73,13 @@ class AgentLoader:
return None return None
def _load_from_submodule(self, agent_name: str) -> BaseAgent: def _load_from_submodule(self, agent_name: str) -> Optional[BaseAgent]:
# Load for case: Import "<agent_name>.agent" and look for "root_agent" # Load for case: Import "{agent_name}.agent" and look for "root_agent"
# Covers structure: agents_dir/agent_name_folder/agent.py (where agent.py has root_agent) # Covers structure: agents_dir/{agent_name}/agent.py (with root_agent defined in the module)
try: try:
module_candidate = importlib.import_module(f"{agent_name}.agent") module_candidate = importlib.import_module(f"{agent_name}.agent")
if hasattr(module_candidate, "root_agent"): if hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent in %s.agent", agent_name) logger.info("Found root_agent in %s.agent", agent_name)
if isinstance(module_candidate.root_agent, BaseAgent): if isinstance(module_candidate.root_agent, BaseAgent):
return module_candidate.root_agent return module_candidate.root_agent
else: else:
@ -106,32 +107,28 @@ class AgentLoader:
) )
envs.load_dotenv_for_agent(agent_name, str(self.agents_dir)) envs.load_dotenv_for_agent(agent_name, str(self.agents_dir))
root_agent = self._load_from_module_or_package(agent_name) if root_agent := self._load_from_submodule(agent_name):
if root_agent:
return root_agent return root_agent
root_agent = self._load_from_submodule(agent_name) if root_agent := self._load_from_module_or_package(agent_name):
if root_agent:
return root_agent return root_agent
# If no root_agent was found by any pattern # If no root_agent was found by any pattern
raise ValueError( raise ValueError(
f"No root_agent found for '{agent_name}'. Searched in" f"No root_agent found for '{agent_name}'. Searched in"
f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent', and" f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent'."
f" via an 'agent' attribute within the '{agent_name}' module/package."
f" Ensure '{self.agents_dir}/{agent_name}' is structured correctly," f" Ensure '{self.agents_dir}/{agent_name}' is structured correctly,"
" an .env file can be loaded if present, and a root_agent is" " an .env file can be loaded if present, and a root_agent is"
" exposed." " exposed."
) )
def load_agent(self, agent_name: str) -> BaseAgent: def load_agent(self, agent_name: str) -> BaseAgent:
"""Load an agent module (with caching & .env) and return its root_agent (asynchronously).""" """Load an agent module (with caching & .env) and return its root_agent."""
if agent_name in self._agent_cache: if agent_name in self._agent_cache:
logger.debug("Returning cached agent for %s (async)", agent_name) logger.debug("Returning cached agent for %s (async)", agent_name)
return self._agent_cache[agent_name] return self._agent_cache[agent_name]
logger.debug("Loading agent %s - not in cache.", agent_name) logger.debug("Loading agent %s - not in cache.", agent_name)
# Assumes this method is called when the context manager (`with self:`) is active
agent = self._perform_load(agent_name) agent = self._perform_load(agent_name)
self._agent_cache[agent_name] = agent self._agent_cache[agent_name] = agent
return agent return agent