refactor: refine agent_loader to load the most common folder structure first and don't including unnecessary case

PiperOrigin-RevId: 764545194
This commit is contained in:
Xiang (Sean) Zhou 2025-05-28 21:05:31 -07:00 committed by Copybara-Service
parent 6157db77f2
commit 623957c0a8

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import importlib
import logging
import sys
from typing import Optional
from . import envs
from ...agents.base_agent import BaseAgent
@ -27,9 +28,13 @@ logger = logging.getLogger("google_adk." + __name__)
class AgentLoader:
"""Centralized agent loading with proper isolation, caching, and .env loading.
Support loading agents from below folder/file structures:
a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
c) agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)
a) {agent_name}.agent as a module name:
agents_dir/{agent_name}/agent.py (with root_agent defined in the module)
b) {agent_name} as a module name
agents_dir/{agent_name}.py (with root_agent defined in the module)
c) {agent_name} as a package name
agents_dir/{agent_name}/__init__.py (with root_agent in the package)
"""
def __init__(self, agents_dir: str):
@ -37,30 +42,26 @@ class AgentLoader:
self._original_sys_path = None
self._agent_cache: dict[str, BaseAgent] = {}
def _load_from_module_or_package(self, agent_name: str) -> BaseAgent:
# Load for case: Import "<agent_name>" (as a package or module)
def _load_from_module_or_package(
self, agent_name: str
) -> Optional[BaseAgent]:
# Load for case: Import "{agent_name}" (as a package or module)
# Covers structures:
# a) agents_dir/agent_name.py (with root_agent or agent.root_agent in it)
# b) agents_dir/agent_name_folder/__init__.py (with root_agent or agent.root_agent in the package)
# a) agents_dir/{agent_name}.py (with root_agent in the module)
# b) agents_dir/{agent_name}/__init__.py (with root_agent in the package)
try:
module_candidate = importlib.import_module(agent_name)
# Check for "root_agent" directly in "<agent_name>" module/package
# Check for "root_agent" directly in "{agent_name}" module/package
if hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent directly in %s", agent_name)
return module_candidate.root_agent
# Check for "<agent_name>.agent.root_agent" structure (e.g. agent_name is a package,
# and it has an 'agent' submodule/attribute which in turn has 'root_agent')
if hasattr(module_candidate, "agent") and hasattr(
module_candidate.agent, "root_agent"
):
logger.debug("Found root_agent in %s.agent attribute", agent_name)
if isinstance(module_candidate.agent, BaseAgent):
return module_candidate.agent.root_agent
if isinstance(module_candidate.root_agent, BaseAgent):
return module_candidate.root_agent
else:
logger.warning(
"Root agent found is not an instance of BaseAgent. But a type %s",
type(module_candidate.agent),
type(module_candidate.root_agent),
)
except ModuleNotFoundError:
logger.debug("Module %s itself not found.", agent_name)
# Re-raise as ValueError to be caught by the final error message construction
@ -72,13 +73,13 @@ class AgentLoader:
return None
def _load_from_submodule(self, agent_name: str) -> BaseAgent:
# Load for case: Import "<agent_name>.agent" and look for "root_agent"
# Covers structure: agents_dir/agent_name_folder/agent.py (where agent.py has root_agent)
def _load_from_submodule(self, agent_name: str) -> Optional[BaseAgent]:
# Load for case: Import "{agent_name}.agent" and look for "root_agent"
# Covers structure: agents_dir/{agent_name}/agent.py (with root_agent defined in the module)
try:
module_candidate = importlib.import_module(f"{agent_name}.agent")
if hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent in %s.agent", agent_name)
logger.info("Found root_agent in %s.agent", agent_name)
if isinstance(module_candidate.root_agent, BaseAgent):
return module_candidate.root_agent
else:
@ -106,32 +107,28 @@ class AgentLoader:
)
envs.load_dotenv_for_agent(agent_name, str(self.agents_dir))
root_agent = self._load_from_module_or_package(agent_name)
if root_agent:
if root_agent := self._load_from_submodule(agent_name):
return root_agent
root_agent = self._load_from_submodule(agent_name)
if root_agent:
if root_agent := self._load_from_module_or_package(agent_name):
return root_agent
# If no root_agent was found by any pattern
raise ValueError(
f"No root_agent found for '{agent_name}'. Searched in"
f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent', and"
f" via an 'agent' attribute within the '{agent_name}' module/package."
f" '{agent_name}.agent.root_agent', '{agent_name}.root_agent'."
f" Ensure '{self.agents_dir}/{agent_name}' is structured correctly,"
" an .env file can be loaded if present, and a root_agent is"
" exposed."
)
def load_agent(self, agent_name: str) -> BaseAgent:
"""Load an agent module (with caching & .env) and return its root_agent (asynchronously)."""
"""Load an agent module (with caching & .env) and return its root_agent."""
if agent_name in self._agent_cache:
logger.debug("Returning cached agent for %s (async)", agent_name)
return self._agent_cache[agent_name]
logger.debug("Loading agent %s - not in cache.", agent_name)
# Assumes this method is called when the context manager (`with self:`) is active
agent = self._perform_load(agent_name)
self._agent_cache[agent_name] = agent
return agent