From 791ac9eb68c14a5674f77d33b55bd5b3a55dfbe8 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 28 May 2025 21:08:55 -0700 Subject: [PATCH] refactor: refactor cli to use agent loader PiperOrigin-RevId: 764546019 --- src/google/adk/cli/cli.py | 18 ++++++++-------- tests/unittests/cli/utils/test_cli.py | 31 +++++++++------------------ 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index d23cd7e..aceb3fc 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from datetime import datetime -import importlib -import os -import sys from typing import Optional import click @@ -30,6 +29,7 @@ from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session from .utils import envs +from .utils.agent_loader import AgentLoader class InputFile(BaseModel): @@ -122,19 +122,17 @@ async def run_cli( save_session: bool, whether to save the session on exit. session_id: Optional[str], the session ID to save the session to on exit. """ - if agent_parent_dir not in sys.path: - sys.path.append(agent_parent_dir) artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() - agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) - agent_module = importlib.import_module(agent_folder_name) user_id = 'test_user' session = await session_service.create_session( app_name=agent_folder_name, user_id=user_id ) - root_agent = agent_module.agent.root_agent + root_agent = AgentLoader(agents_dir=agent_parent_dir).load_agent( + agent_folder_name + ) envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) if input_file: session = await run_input_file( @@ -177,7 +175,9 @@ async def run_cli( if save_session: session_id = session_id or input('Session ID to save: ') - session_path = f'{agent_module_path}/{session_id}.session.json' + session_path = ( + f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json' + ) # Fetch the session again to get all the details. session = await session_service.get_session( diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index d79e199..1721885 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -18,7 +18,7 @@ from __future__ import annotations import json from pathlib import Path -import sys +from textwrap import dedent import types from typing import Any from typing import Dict @@ -87,7 +87,7 @@ def _patch_types_and_runner(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture() -def fake_agent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): +def fake_agent(tmp_path: Path): """Create a minimal importable agent package and patch importlib.""" parent_dir = tmp_path / "agents" @@ -95,27 +95,16 @@ def fake_agent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): agent_dir = parent_dir / "fake_agent" agent_dir.mkdir() # __init__.py exposes root_agent with .name - (agent_dir / "__init__.py").write_text( - "from types import SimpleNamespace\n" - "root_agent = SimpleNamespace(name='fake_root')\n" - ) + (agent_dir / "__init__.py").write_text(dedent(""" + from google.adk.agents.base_agent import BaseAgent + class FakeAgent(BaseAgent): + def __init__(self, name): + super().__init__(name=name) - # Ensure importable via sys.path - sys.path.insert(0, str(parent_dir)) + root_agent = FakeAgent(name="fake_root") + """)) - import importlib - - module = importlib.import_module("fake_agent") - fake_module = types.SimpleNamespace(agent=module) - - monkeypatch.setattr(importlib, "import_module", lambda n: fake_module) - monkeypatch.setattr(cli.envs, "load_dotenv_for_agent", lambda *a, **k: None) - - yield parent_dir, "fake_agent" - - # Cleanup - sys.path.remove(str(parent_dir)) - del sys.modules["fake_agent"] + return parent_dir, "fake_agent" # _run_input_file