refactor: refactor cli to use agent loader

PiperOrigin-RevId: 764546019
This commit is contained in:
Xiang (Sean) Zhou 2025-05-28 21:08:55 -07:00 committed by Copybara-Service
parent 623957c0a8
commit 791ac9eb68
2 changed files with 19 additions and 30 deletions

View File

@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
from datetime import datetime from datetime import datetime
import importlib
import os
import sys
from typing import Optional from typing import Optional
import click import click
@ -30,6 +29,7 @@ from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session from ..sessions.session import Session
from .utils import envs from .utils import envs
from .utils.agent_loader import AgentLoader
class InputFile(BaseModel): class InputFile(BaseModel):
@ -122,19 +122,17 @@ async def run_cli(
save_session: bool, whether to save the session on exit. save_session: bool, whether to save the session on exit.
session_id: Optional[str], the session ID to save the session to on exit. session_id: Optional[str], the session ID to save the session to on exit.
""" """
if agent_parent_dir not in sys.path:
sys.path.append(agent_parent_dir)
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService() session_service = InMemorySessionService()
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user' user_id = 'test_user'
session = await session_service.create_session( session = await session_service.create_session(
app_name=agent_folder_name, user_id=user_id app_name=agent_folder_name, user_id=user_id
) )
root_agent = agent_module.agent.root_agent root_agent = AgentLoader(agents_dir=agent_parent_dir).load_agent(
agent_folder_name
)
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if input_file: if input_file:
session = await run_input_file( session = await run_input_file(
@ -177,7 +175,9 @@ async def run_cli(
if save_session: if save_session:
session_id = session_id or input('Session ID to save: ') session_id = session_id or input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json' session_path = (
f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json'
)
# Fetch the session again to get all the details. # Fetch the session again to get all the details.
session = await session_service.get_session( session = await session_service.get_session(

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import json import json
from pathlib import Path from pathlib import Path
import sys from textwrap import dedent
import types import types
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -87,7 +87,7 @@ def _patch_types_and_runner(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture() @pytest.fixture()
def fake_agent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): def fake_agent(tmp_path: Path):
"""Create a minimal importable agent package and patch importlib.""" """Create a minimal importable agent package and patch importlib."""
parent_dir = tmp_path / "agents" parent_dir = tmp_path / "agents"
@ -95,27 +95,16 @@ def fake_agent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
agent_dir = parent_dir / "fake_agent" agent_dir = parent_dir / "fake_agent"
agent_dir.mkdir() agent_dir.mkdir()
# __init__.py exposes root_agent with .name # __init__.py exposes root_agent with .name
(agent_dir / "__init__.py").write_text( (agent_dir / "__init__.py").write_text(dedent("""
"from types import SimpleNamespace\n" from google.adk.agents.base_agent import BaseAgent
"root_agent = SimpleNamespace(name='fake_root')\n" class FakeAgent(BaseAgent):
) def __init__(self, name):
super().__init__(name=name)
# Ensure importable via sys.path root_agent = FakeAgent(name="fake_root")
sys.path.insert(0, str(parent_dir)) """))
import importlib return parent_dir, "fake_agent"
module = importlib.import_module("fake_agent")
fake_module = types.SimpleNamespace(agent=module)
monkeypatch.setattr(importlib, "import_module", lambda n: fake_module)
monkeypatch.setattr(cli.envs, "load_dotenv_for_agent", lambda *a, **k: None)
yield parent_dir, "fake_agent"
# Cleanup
sys.path.remove(str(parent_dir))
del sys.modules["fake_agent"]
# _run_input_file # _run_input_file