chore: reformat the codes using autoformat.sh

PiperOrigin-RevId: 762004002
This commit is contained in:
Xiang (Sean) Zhou 2025-05-22 09:43:03 -07:00 committed by Copybara-Service
parent a2263b1808
commit ff8a3c9b43
23 changed files with 496 additions and 447 deletions

View File

@ -21,6 +21,7 @@ from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from unittest import mock from unittest import mock
from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
@ -30,6 +31,7 @@ from google.genai import types
import pytest import pytest
import pytest_mock import pytest_mock
from typing_extensions import override from typing_extensions import override
from .. import testing_utils from .. import testing_utils

View File

@ -1,7 +1,11 @@
import pytest from unittest.mock import AsyncMock
from unittest.mock import MagicMock, AsyncMock, patch from unittest.mock import MagicMock
from google.adk.agents.live_request_queue import LiveRequest, LiveRequestQueue from unittest.mock import patch
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.genai import types from google.genai import types
import pytest
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -15,7 +15,8 @@
"""Unit tests for canonical_xxx fields in LlmAgent.""" """Unit tests for canonical_xxx fields in LlmAgent."""
from typing import Any from typing import Any
from typing import Optional, cast from typing import cast
from typing import Optional
from google.adk.agents.callback_context import CallbackContext from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
@ -146,6 +147,7 @@ async def test_canonical_global_instruction():
assert canonical_global_instruction == 'global instruction: state_value' assert canonical_global_instruction == 'global instruction: state_value'
assert bypass_state_injection assert bypass_state_injection
async def test_async_canonical_global_instruction(): async def test_async_canonical_global_instruction():
async def _global_instruction_provider(ctx: ReadonlyContext) -> str: async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
return f'global instruction: {ctx.state["state_var"]}' return f'global instruction: {ctx.state["state_var"]}'

View File

@ -1,7 +1,8 @@
import pytest
from unittest.mock import MagicMock
from types import MappingProxyType from types import MappingProxyType
from unittest.mock import MagicMock
from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.readonly_context import ReadonlyContext
import pytest
@pytest.fixture @pytest.fixture

View File

@ -1,8 +1,10 @@
import pytest
import sys
import logging import logging
from unittest.mock import patch, ANY import sys
from unittest.mock import ANY
from unittest.mock import patch
from google.adk.agents.run_config import RunConfig from google.adk.agents.run_config import RunConfig
import pytest
def test_validate_max_llm_calls_valid(): def test_validate_max_llm_calls_valid():

View File

@ -17,12 +17,11 @@
import enum import enum
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from unittest import mock
from google.adk.artifacts import GcsArtifactService from google.adk.artifacts import GcsArtifactService
from google.adk.artifacts import InMemoryArtifactService from google.adk.artifacts import InMemoryArtifactService
from google.genai import types from google.genai import types
from unittest import mock
import pytest import pytest
Enum = enum.Enum Enum = enum.Enum

View File

@ -15,19 +15,18 @@
import copy import copy
from unittest.mock import patch from unittest.mock import patch
import pytest
from fastapi.openapi.models import APIKey from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows from fastapi.openapi.models import OAuthFlows
from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_handler import AuthHandler from google.adk.auth.auth_handler import AuthHandler
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.auth_tool import AuthConfig
import pytest
# Mock classes for testing # Mock classes for testing

View File

@ -16,182 +16,195 @@
from __future__ import annotations from __future__ import annotations
import click
import json import json
import pytest from pathlib import Path
import sys import sys
import types import types
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
import click
import google.adk.cli.cli as cli import google.adk.cli.cli as cli
import pytest
from pathlib import Path
from typing import Any, Dict, List, Tuple
# Helpers # Helpers
class _Recorder: class _Recorder:
"""Callable that records every invocation.""" """Callable that records every invocation."""
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
def __call__(self, *args: Any, **kwargs: Any) -> None: def __call__(self, *args: Any, **kwargs: Any) -> None:
self.calls.append((args, kwargs)) self.calls.append((args, kwargs))
# Fixtures # Fixtures
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None: def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None:
"""Silence click output in every test.""" """Silence click output in every test."""
monkeypatch.setattr(click, "echo", lambda *a, **k: None) monkeypatch.setattr(click, "echo", lambda *a, **k: None)
monkeypatch.setattr(click, "secho", lambda *a, **k: None) monkeypatch.setattr(click, "secho", lambda *a, **k: None)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _patch_types_and_runner(monkeypatch: pytest.MonkeyPatch) -> None: def _patch_types_and_runner(monkeypatch: pytest.MonkeyPatch) -> None:
"""Replace google.genai.types and Runner with lightweight fakes.""" """Replace google.genai.types and Runner with lightweight fakes."""
# Dummy Part / Content # Dummy Part / Content
class _Part: class _Part:
def __init__(self, text: str | None = "") -> None:
self.text = text
class _Content: def __init__(self, text: str | None = "") -> None:
def __init__(self, role: str, parts: List[_Part]) -> None: self.text = text
self.role = role
self.parts = parts
monkeypatch.setattr(cli.types, "Part", _Part) class _Content:
monkeypatch.setattr(cli.types, "Content", _Content)
# Fake Runner yielding a single assistant echo def __init__(self, role: str, parts: List[_Part]) -> None:
class _FakeRunner: self.role = role
def __init__(self, *a: Any, **k: Any) -> None: ... self.parts = parts
async def run_async(self, *a: Any, **k: Any): monkeypatch.setattr(cli.types, "Part", _Part)
message = a[2] if len(a) >= 3 else k["new_message"] monkeypatch.setattr(cli.types, "Content", _Content)
text = message.parts[0].text if message.parts else ""
response = _Content("assistant", [_Part(f"echo:{text}")])
yield types.SimpleNamespace(author="assistant", content=response)
monkeypatch.setattr(cli, "Runner", _FakeRunner) # Fake Runner yielding a single assistant echo
class _FakeRunner:
def __init__(self, *a: Any, **k: Any) -> None:
...
async def run_async(self, *a: Any, **k: Any):
message = a[2] if len(a) >= 3 else k["new_message"]
text = message.parts[0].text if message.parts else ""
response = _Content("assistant", [_Part(f"echo:{text}")])
yield types.SimpleNamespace(author="assistant", content=response)
monkeypatch.setattr(cli, "Runner", _FakeRunner)
@pytest.fixture() @pytest.fixture()
def fake_agent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): def fake_agent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Create a minimal importable agent package and patch importlib.""" """Create a minimal importable agent package and patch importlib."""
parent_dir = tmp_path / "agents" parent_dir = tmp_path / "agents"
parent_dir.mkdir() parent_dir.mkdir()
agent_dir = parent_dir / "fake_agent" agent_dir = parent_dir / "fake_agent"
agent_dir.mkdir() agent_dir.mkdir()
# __init__.py exposes root_agent with .name # __init__.py exposes root_agent with .name
(agent_dir / "__init__.py").write_text( (agent_dir / "__init__.py").write_text(
"from types import SimpleNamespace\n" "from types import SimpleNamespace\n"
"root_agent = SimpleNamespace(name='fake_root')\n" "root_agent = SimpleNamespace(name='fake_root')\n"
) )
# Ensure importable via sys.path # Ensure importable via sys.path
sys.path.insert(0, str(parent_dir)) sys.path.insert(0, str(parent_dir))
import importlib import importlib
module = importlib.import_module("fake_agent") module = importlib.import_module("fake_agent")
fake_module = types.SimpleNamespace(agent=module) fake_module = types.SimpleNamespace(agent=module)
monkeypatch.setattr(importlib, "import_module", lambda n: fake_module) monkeypatch.setattr(importlib, "import_module", lambda n: fake_module)
monkeypatch.setattr(cli.envs, "load_dotenv_for_agent", lambda *a, **k: None) monkeypatch.setattr(cli.envs, "load_dotenv_for_agent", lambda *a, **k: None)
yield parent_dir, "fake_agent" yield parent_dir, "fake_agent"
# Cleanup # Cleanup
sys.path.remove(str(parent_dir)) sys.path.remove(str(parent_dir))
del sys.modules["fake_agent"] del sys.modules["fake_agent"]
# _run_input_file # _run_input_file
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_input_file_outputs(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: async def test_run_input_file_outputs(
"""run_input_file should echo user & assistant messages and return a populated session.""" tmp_path: Path, monkeypatch: pytest.MonkeyPatch
recorder: List[str] = [] ) -> None:
"""run_input_file should echo user & assistant messages and return a populated session."""
recorder: List[str] = []
def _echo(msg: str) -> None: def _echo(msg: str) -> None:
recorder.append(msg) recorder.append(msg)
monkeypatch.setattr(click, "echo", _echo) monkeypatch.setattr(click, "echo", _echo)
input_json = { input_json = {
"state": {"foo": "bar"}, "state": {"foo": "bar"},
"queries": ["hello world"], "queries": ["hello world"],
} }
input_path = tmp_path / "input.json" input_path = tmp_path / "input.json"
input_path.write_text(json.dumps(input_json)) input_path.write_text(json.dumps(input_json))
artifact_service = cli.InMemoryArtifactService() artifact_service = cli.InMemoryArtifactService()
session_service = cli.InMemorySessionService() session_service = cli.InMemorySessionService()
dummy_root = types.SimpleNamespace(name="root") dummy_root = types.SimpleNamespace(name="root")
session = await cli.run_input_file( session = await cli.run_input_file(
app_name="app", app_name="app",
user_id="user", user_id="user",
root_agent=dummy_root, root_agent=dummy_root,
artifact_service=artifact_service, artifact_service=artifact_service,
session_service=session_service, session_service=session_service,
input_path=str(input_path), input_path=str(input_path),
) )
assert session.state["foo"] == "bar" assert session.state["foo"] == "bar"
assert any("[user]:" in line for line in recorder) assert any("[user]:" in line for line in recorder)
assert any("[assistant]:" in line for line in recorder) assert any("[assistant]:" in line for line in recorder)
# _run_cli (input_file branch) # _run_cli (input_file branch)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None: async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None:
"""run_cli should process an input file without raising and without saving.""" """run_cli should process an input file without raising and without saving."""
parent_dir, folder_name = fake_agent parent_dir, folder_name = fake_agent
input_json = {"state": {}, "queries": ["ping"]} input_json = {"state": {}, "queries": ["ping"]}
input_path = tmp_path / "in.json" input_path = tmp_path / "in.json"
input_path.write_text(json.dumps(input_json)) input_path.write_text(json.dumps(input_json))
await cli.run_cli( await cli.run_cli(
agent_parent_dir=str(parent_dir), agent_parent_dir=str(parent_dir),
agent_folder_name=folder_name, agent_folder_name=folder_name,
input_file=str(input_path), input_file=str(input_path),
saved_session_file=None, saved_session_file=None,
save_session=False, save_session=False,
) )
# _run_cli (interactive + save session branch) # _run_cli (interactive + save session branch)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_cli_save_session(fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: async def test_run_cli_save_session(
"""run_cli should save a session file when save_session=True.""" fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
parent_dir, folder_name = fake_agent ) -> None:
"""run_cli should save a session file when save_session=True."""
parent_dir, folder_name = fake_agent
# Simulate user typing 'exit' followed by session id 'sess123' # Simulate user typing 'exit' followed by session id 'sess123'
responses = iter(["exit", "sess123"]) responses = iter(["exit", "sess123"])
monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(responses)) monkeypatch.setattr("builtins.input", lambda *_a, **_k: next(responses))
session_file = Path(parent_dir) / folder_name / "sess123.session.json" session_file = Path(parent_dir) / folder_name / "sess123.session.json"
if session_file.exists(): if session_file.exists():
session_file.unlink() session_file.unlink()
await cli.run_cli( await cli.run_cli(
agent_parent_dir=str(parent_dir), agent_parent_dir=str(parent_dir),
agent_folder_name=folder_name, agent_folder_name=folder_name,
input_file=None, input_file=None,
saved_session_file=None, saved_session_file=None,
save_session=True, save_session=True,
) )
assert session_file.exists() assert session_file.exists()
data = json.loads(session_file.read_text()) data = json.loads(session_file.read_text())
# The saved JSON should at least contain id and events keys # The saved JSON should at least contain id and events keys
assert "id" in data and "events" in data assert "id" in data and "events" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_interactively_whitespace_and_exit(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: async def test_run_interactively_whitespace_and_exit(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""run_interactively should skip blank input, echo once, then exit.""" """run_interactively should skip blank input, echo once, then exit."""
# make a session that belongs to dummy agent # make a session that belongs to dummy agent
svc = cli.InMemorySessionService() svc = cli.InMemorySessionService()

View File

@ -17,214 +17,239 @@
from __future__ import annotations from __future__ import annotations
import click
import os import os
import pytest
import subprocess
import google.adk.cli.cli_create as cli_create
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple import subprocess
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
import click
import google.adk.cli.cli_create as cli_create
import pytest
# Helpers # Helpers
class _Recorder: class _Recorder:
"""A callable object that records every invocation.""" """A callable object that records every invocation."""
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401
self.calls.append((args, kwargs)) self.calls.append((args, kwargs))
# Fixtures # Fixtures
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None: def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None:
"""Silence click output in every test.""" """Silence click output in every test."""
monkeypatch.setattr(click, "echo", lambda *a, **k: None) monkeypatch.setattr(click, "echo", lambda *a, **k: None)
monkeypatch.setattr(click, "secho", lambda *a, **k: None) monkeypatch.setattr(click, "secho", lambda *a, **k: None)
@pytest.fixture() @pytest.fixture()
def agent_folder(tmp_path: Path) -> Path: def agent_folder(tmp_path: Path) -> Path:
"""Return a temporary path that will hold generated agent sources.""" """Return a temporary path that will hold generated agent sources."""
return tmp_path / "agent" return tmp_path / "agent"
# _generate_files # _generate_files
def test_generate_files_with_api_key(agent_folder: Path) -> None: def test_generate_files_with_api_key(agent_folder: Path) -> None:
"""Files should be created with the API-key backend and correct .env flags.""" """Files should be created with the API-key backend and correct .env flags."""
cli_create._generate_files( cli_create._generate_files(
str(agent_folder), str(agent_folder),
google_api_key="dummy-key", google_api_key="dummy-key",
model="gemini-2.0-flash-001", model="gemini-2.0-flash-001",
) )
env_content = (agent_folder / ".env").read_text() env_content = (agent_folder / ".env").read_text()
assert "GOOGLE_API_KEY=dummy-key" in env_content assert "GOOGLE_API_KEY=dummy-key" in env_content
assert "GOOGLE_GENAI_USE_VERTEXAI=0" in env_content assert "GOOGLE_GENAI_USE_VERTEXAI=0" in env_content
assert (agent_folder / "agent.py").exists() assert (agent_folder / "agent.py").exists()
assert (agent_folder / "__init__.py").exists() assert (agent_folder / "__init__.py").exists()
def test_generate_files_with_gcp(agent_folder: Path) -> None: def test_generate_files_with_gcp(agent_folder: Path) -> None:
"""Files should be created with Vertex AI backend and correct .env flags.""" """Files should be created with Vertex AI backend and correct .env flags."""
cli_create._generate_files( cli_create._generate_files(
str(agent_folder), str(agent_folder),
google_cloud_project="proj", google_cloud_project="proj",
google_cloud_region="us-central1", google_cloud_region="us-central1",
model="gemini-2.0-flash-001", model="gemini-2.0-flash-001",
) )
env_content = (agent_folder / ".env").read_text() env_content = (agent_folder / ".env").read_text()
assert "GOOGLE_CLOUD_PROJECT=proj" in env_content assert "GOOGLE_CLOUD_PROJECT=proj" in env_content
assert "GOOGLE_CLOUD_LOCATION=us-central1" in env_content assert "GOOGLE_CLOUD_LOCATION=us-central1" in env_content
assert "GOOGLE_GENAI_USE_VERTEXAI=1" in env_content assert "GOOGLE_GENAI_USE_VERTEXAI=1" in env_content
def test_generate_files_overwrite(agent_folder: Path) -> None: def test_generate_files_overwrite(agent_folder: Path) -> None:
"""Existing files should be overwritten when generating again.""" """Existing files should be overwritten when generating again."""
agent_folder.mkdir(parents=True, exist_ok=True) agent_folder.mkdir(parents=True, exist_ok=True)
(agent_folder / ".env").write_text("OLD") (agent_folder / ".env").write_text("OLD")
cli_create._generate_files( cli_create._generate_files(
str(agent_folder), str(agent_folder),
google_api_key="new-key", google_api_key="new-key",
model="gemini-2.0-flash-001", model="gemini-2.0-flash-001",
) )
assert "GOOGLE_API_KEY=new-key" in (agent_folder / ".env").read_text() assert "GOOGLE_API_KEY=new-key" in (agent_folder / ".env").read_text()
def test_generate_files_permission_error(monkeypatch: pytest.MonkeyPatch, agent_folder: Path) -> None: def test_generate_files_permission_error(
"""PermissionError raised by os.makedirs should propagate.""" monkeypatch: pytest.MonkeyPatch, agent_folder: Path
monkeypatch.setattr(os, "makedirs", lambda *a, **k: (_ for _ in ()).throw(PermissionError())) ) -> None:
with pytest.raises(PermissionError): """PermissionError raised by os.makedirs should propagate."""
cli_create._generate_files(str(agent_folder), model="gemini-2.0-flash-001") monkeypatch.setattr(
os, "makedirs", lambda *a, **k: (_ for _ in ()).throw(PermissionError())
)
with pytest.raises(PermissionError):
cli_create._generate_files(str(agent_folder), model="gemini-2.0-flash-001")
def test_generate_files_no_params(agent_folder: Path) -> None: def test_generate_files_no_params(agent_folder: Path) -> None:
"""No backend parameters → minimal .env file is generated.""" """No backend parameters → minimal .env file is generated."""
cli_create._generate_files(str(agent_folder), model="gemini-2.0-flash-001") cli_create._generate_files(str(agent_folder), model="gemini-2.0-flash-001")
env_content = (agent_folder / ".env").read_text() env_content = (agent_folder / ".env").read_text()
for key in ("GOOGLE_API_KEY", "GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_LOCATION", "GOOGLE_GENAI_USE_VERTEXAI"): for key in (
assert key not in env_content "GOOGLE_API_KEY",
"GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_GENAI_USE_VERTEXAI",
):
assert key not in env_content
# run_cmd # run_cmd
def test_run_cmd_overwrite_reject(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: def test_run_cmd_overwrite_reject(
"""User rejecting overwrite should trigger click.Abort.""" monkeypatch: pytest.MonkeyPatch, tmp_path: Path
agent_name = "agent" ) -> None:
agent_dir = tmp_path / agent_name """User rejecting overwrite should trigger click.Abort."""
agent_dir.mkdir() agent_name = "agent"
(agent_dir / "dummy.txt").write_text("dummy") agent_dir = tmp_path / agent_name
agent_dir.mkdir()
(agent_dir / "dummy.txt").write_text("dummy")
monkeypatch.setattr(os, "getcwd", lambda: str(tmp_path)) monkeypatch.setattr(os, "getcwd", lambda: str(tmp_path))
monkeypatch.setattr(os.path, "exists", lambda _p: True) monkeypatch.setattr(os.path, "exists", lambda _p: True)
monkeypatch.setattr(os, "listdir", lambda _p: ["dummy.txt"]) monkeypatch.setattr(os, "listdir", lambda _p: ["dummy.txt"])
monkeypatch.setattr(click, "confirm", lambda *a, **k: False) monkeypatch.setattr(click, "confirm", lambda *a, **k: False)
with pytest.raises(click.Abort): with pytest.raises(click.Abort):
cli_create.run_cmd( cli_create.run_cmd(
agent_name, agent_name,
model="gemini-2.0-flash-001", model="gemini-2.0-flash-001",
google_api_key=None, google_api_key=None,
google_cloud_project=None, google_cloud_project=None,
google_cloud_region=None, google_cloud_region=None,
) )
# Prompt helpers # Prompt helpers
def test_prompt_for_google_cloud(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_for_google_cloud(monkeypatch: pytest.MonkeyPatch) -> None:
"""Prompt should return the project input.""" """Prompt should return the project input."""
monkeypatch.setattr(click, "prompt", lambda *a, **k: "test-proj") monkeypatch.setattr(click, "prompt", lambda *a, **k: "test-proj")
assert cli_create._prompt_for_google_cloud(None) == "test-proj" assert cli_create._prompt_for_google_cloud(None) == "test-proj"
def test_prompt_for_google_cloud_region(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_for_google_cloud_region(
"""Prompt should return the region input.""" monkeypatch: pytest.MonkeyPatch,
monkeypatch.setattr(click, "prompt", lambda *a, **k: "asia-northeast1") ) -> None:
assert cli_create._prompt_for_google_cloud_region(None) == "asia-northeast1" """Prompt should return the region input."""
monkeypatch.setattr(click, "prompt", lambda *a, **k: "asia-northeast1")
assert cli_create._prompt_for_google_cloud_region(None) == "asia-northeast1"
def test_prompt_for_google_api_key(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_for_google_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
"""Prompt should return the API-key input.""" """Prompt should return the API-key input."""
monkeypatch.setattr(click, "prompt", lambda *a, **k: "api-key") monkeypatch.setattr(click, "prompt", lambda *a, **k: "api-key")
assert cli_create._prompt_for_google_api_key(None) == "api-key" assert cli_create._prompt_for_google_api_key(None) == "api-key"
def test_prompt_for_model_gemini(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_for_model_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
"""Selecting option '1' should return the default Gemini model string.""" """Selecting option '1' should return the default Gemini model string."""
monkeypatch.setattr(click, "prompt", lambda *a, **k: "1") monkeypatch.setattr(click, "prompt", lambda *a, **k: "1")
assert cli_create._prompt_for_model() == "gemini-2.0-flash-001" assert cli_create._prompt_for_model() == "gemini-2.0-flash-001"
def test_prompt_for_model_other(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_for_model_other(monkeypatch: pytest.MonkeyPatch) -> None:
"""Selecting option '2' should return placeholder and call secho.""" """Selecting option '2' should return placeholder and call secho."""
called: Dict[str, bool] = {} called: Dict[str, bool] = {}
monkeypatch.setattr(click, "prompt", lambda *a, **k: "2") monkeypatch.setattr(click, "prompt", lambda *a, **k: "2")
def _fake_secho(*_a: Any, **_k: Any) -> None: def _fake_secho(*_a: Any, **_k: Any) -> None:
called["secho"] = True called["secho"] = True
monkeypatch.setattr(click, "secho", _fake_secho)
assert cli_create._prompt_for_model() == "<FILL_IN_MODEL>"
assert called.get("secho") is True
monkeypatch.setattr(click, "secho", _fake_secho)
assert cli_create._prompt_for_model() == "<FILL_IN_MODEL>"
assert called.get("secho") is True
# Backend selection helper # Backend selection helper
def test_prompt_to_choose_backend_api(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_to_choose_backend_api(monkeypatch: pytest.MonkeyPatch) -> None:
"""Choosing API-key backend returns (api_key, None, None).""" """Choosing API-key backend returns (api_key, None, None)."""
monkeypatch.setattr(click, "prompt", lambda *a, **k: "1") monkeypatch.setattr(click, "prompt", lambda *a, **k: "1")
monkeypatch.setattr(cli_create, "_prompt_for_google_api_key", lambda _v: "api-key") monkeypatch.setattr(
cli_create, "_prompt_for_google_api_key", lambda _v: "api-key"
)
api_key, proj, region = cli_create._prompt_to_choose_backend(None, None, None) api_key, proj, region = cli_create._prompt_to_choose_backend(None, None, None)
assert api_key == "api-key" assert api_key == "api-key"
assert proj is None and region is None assert proj is None and region is None
def test_prompt_to_choose_backend_vertex(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_to_choose_backend_vertex(
"""Choosing Vertex backend returns (None, project, region).""" monkeypatch: pytest.MonkeyPatch,
monkeypatch.setattr(click, "prompt", lambda *a, **k: "2") ) -> None:
monkeypatch.setattr(cli_create, "_prompt_for_google_cloud", lambda _v: "proj") """Choosing Vertex backend returns (None, project, region)."""
monkeypatch.setattr(cli_create, "_prompt_for_google_cloud_region", lambda _v: "region") monkeypatch.setattr(click, "prompt", lambda *a, **k: "2")
monkeypatch.setattr(cli_create, "_prompt_for_google_cloud", lambda _v: "proj")
api_key, proj, region = cli_create._prompt_to_choose_backend(None, None, None) monkeypatch.setattr(
assert api_key is None cli_create, "_prompt_for_google_cloud_region", lambda _v: "region"
assert proj == "proj" )
assert region == "region"
api_key, proj, region = cli_create._prompt_to_choose_backend(None, None, None)
assert api_key is None
assert proj == "proj"
assert region == "region"
# prompt_str # prompt_str
def test_prompt_str_non_empty(monkeypatch: pytest.MonkeyPatch) -> None: def test_prompt_str_non_empty(monkeypatch: pytest.MonkeyPatch) -> None:
"""_prompt_str should retry until a non-blank string is provided.""" """_prompt_str should retry until a non-blank string is provided."""
responses = iter(["", " ", "valid"]) responses = iter(["", " ", "valid"])
monkeypatch.setattr(click, "prompt", lambda *_a, **_k: next(responses)) monkeypatch.setattr(click, "prompt", lambda *_a, **_k: next(responses))
assert cli_create._prompt_str("dummy") == "valid" assert cli_create._prompt_str("dummy") == "valid"
# gcloud fallback helpers # gcloud fallback helpers
def test_get_gcp_project_from_gcloud_fail(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_gcp_project_from_gcloud_fail(
"""Failure of gcloud project lookup should return empty string.""" monkeypatch: pytest.MonkeyPatch,
monkeypatch.setattr( ) -> None:
subprocess, """Failure of gcloud project lookup should return empty string."""
"run", monkeypatch.setattr(
lambda *_a, **_k: (_ for _ in ()).throw(FileNotFoundError()), subprocess,
) "run",
assert cli_create._get_gcp_project_from_gcloud() == "" lambda *_a, **_k: (_ for _ in ()).throw(FileNotFoundError()),
)
assert cli_create._get_gcp_project_from_gcloud() == ""
def test_get_gcp_region_from_gcloud_fail(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_gcp_region_from_gcloud_fail(
"""CalledProcessError should result in empty region string.""" monkeypatch: pytest.MonkeyPatch,
monkeypatch.setattr( ) -> None:
subprocess, """CalledProcessError should result in empty region string."""
"run", monkeypatch.setattr(
lambda *_a, **_k: (_ for _ in ()).throw(subprocess.CalledProcessError(1, "gcloud")), subprocess,
) "run",
assert cli_create._get_gcp_region_from_gcloud() == "" lambda *_a, **_k: (_ for _ in ()).throw(
subprocess.CalledProcessError(1, "gcloud")
),
)
assert cli_create._get_gcp_region_from_gcloud() == ""

View File

@ -17,70 +17,74 @@
from __future__ import annotations from __future__ import annotations
import click from pathlib import Path
import shutil import shutil
import pytest
import subprocess import subprocess
import tempfile import tempfile
import types import types
from typing import Any
import google.adk.cli.cli_deploy as cli_deploy from typing import Callable
from typing import Dict
from pathlib import Path from typing import List
from typing import Any, Callable, Dict, List, Tuple from typing import Tuple
from unittest import mock from unittest import mock
import click
import google.adk.cli.cli_deploy as cli_deploy
import pytest
# Helpers # Helpers
class _Recorder: class _Recorder:
"""A callable object that records every invocation.""" """A callable object that records every invocation."""
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
def __call__(self, *args: Any, **kwargs: Any) -> None: def __call__(self, *args: Any, **kwargs: Any) -> None:
self.calls.append((args, kwargs)) self.calls.append((args, kwargs))
# Fixtures # Fixtures
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None: def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None:
"""Suppress click.echo to keep test output clean.""" """Suppress click.echo to keep test output clean."""
monkeypatch.setattr(click, "echo", lambda *a, **k: None) monkeypatch.setattr(click, "echo", lambda *a, **k: None)
@pytest.fixture() @pytest.fixture()
def agent_dir(tmp_path: Path) -> Callable[[bool], Path]: def agent_dir(tmp_path: Path) -> Callable[[bool], Path]:
"""Return a factory that creates a dummy agent directory tree.""" """Return a factory that creates a dummy agent directory tree."""
def _factory(include_requirements: bool) -> Path: def _factory(include_requirements: bool) -> Path:
base = tmp_path / "agent" base = tmp_path / "agent"
base.mkdir() base.mkdir()
(base / "agent.py").write_text("# dummy agent") (base / "agent.py").write_text("# dummy agent")
(base / "__init__.py").touch() (base / "__init__.py").touch()
if include_requirements: if include_requirements:
(base / "requirements.txt").write_text("pytest\n") (base / "requirements.txt").write_text("pytest\n")
return base return base
return _factory return _factory
# _resolve_project # _resolve_project
def test_resolve_project_with_option() -> None: def test_resolve_project_with_option() -> None:
"""It should return the explicit project value untouched.""" """It should return the explicit project value untouched."""
assert cli_deploy._resolve_project("my-project") == "my-project" assert cli_deploy._resolve_project("my-project") == "my-project"
def test_resolve_project_from_gcloud(monkeypatch: pytest.MonkeyPatch) -> None: def test_resolve_project_from_gcloud(monkeypatch: pytest.MonkeyPatch) -> None:
"""It should fall back to `gcloud config get-value project` when no value supplied.""" """It should fall back to `gcloud config get-value project` when no value supplied."""
monkeypatch.setattr( monkeypatch.setattr(
subprocess, subprocess,
"run", "run",
lambda *a, **k: types.SimpleNamespace(stdout="gcp-proj\n"), lambda *a, **k: types.SimpleNamespace(stdout="gcp-proj\n"),
) )
with mock.patch("click.echo") as mocked_echo: with mock.patch("click.echo") as mocked_echo:
assert cli_deploy._resolve_project(None) == "gcp-proj" assert cli_deploy._resolve_project(None) == "gcp-proj"
mocked_echo.assert_called_once() mocked_echo.assert_called_once()
# to_cloud_run # to_cloud_run
@ -90,81 +94,83 @@ def test_to_cloud_run_happy_path(
agent_dir: Callable[[bool], Path], agent_dir: Callable[[bool], Path],
include_requirements: bool, include_requirements: bool,
) -> None: ) -> None:
""" """
End-to-end execution test for `to_cloud_run` covering both presence and End-to-end execution test for `to_cloud_run` covering both presence and
absence of *requirements.txt*. absence of *requirements.txt*.
""" """
tmp_dir = Path(tempfile.mkdtemp()) tmp_dir = Path(tempfile.mkdtemp())
src_dir = agent_dir(include_requirements) src_dir = agent_dir(include_requirements)
copy_recorder = _Recorder() copy_recorder = _Recorder()
run_recorder = _Recorder() run_recorder = _Recorder()
# Cache the ORIGINAL copytree before patching # Cache the ORIGINAL copytree before patching
original_copytree = cli_deploy.shutil.copytree original_copytree = cli_deploy.shutil.copytree
def _recording_copytree(*args: Any, **kwargs: Any): def _recording_copytree(*args: Any, **kwargs: Any):
copy_recorder(*args, **kwargs) copy_recorder(*args, **kwargs)
return original_copytree(*args, **kwargs) return original_copytree(*args, **kwargs)
monkeypatch.setattr(cli_deploy.shutil, "copytree", _recording_copytree) monkeypatch.setattr(cli_deploy.shutil, "copytree", _recording_copytree)
# Skip actual cleanup so that we can inspect generated files later. # Skip actual cleanup so that we can inspect generated files later.
monkeypatch.setattr(cli_deploy.shutil, "rmtree", lambda *_a, **_k: None) monkeypatch.setattr(cli_deploy.shutil, "rmtree", lambda *_a, **_k: None)
monkeypatch.setattr(subprocess, "run", run_recorder) monkeypatch.setattr(subprocess, "run", run_recorder)
cli_deploy.to_cloud_run( cli_deploy.to_cloud_run(
agent_folder=str(src_dir), agent_folder=str(src_dir),
project="proj", project="proj",
region="asia-northeast1", region="asia-northeast1",
service_name="svc", service_name="svc",
app_name="app", app_name="app",
temp_folder=str(tmp_dir), temp_folder=str(tmp_dir),
port=8080, port=8080,
trace_to_cloud=True, trace_to_cloud=True,
with_ui=True, with_ui=True,
verbosity="info", verbosity="info",
session_db_url="sqlite://", session_db_url="sqlite://",
adk_version="0.0.5", adk_version="0.0.5",
) )
# Assertions # Assertions
assert len(copy_recorder.calls) == 1, "Agent sources must be copied exactly once." assert (
assert run_recorder.calls, "gcloud command should be executed at least once." len(copy_recorder.calls) == 1
assert (tmp_dir / "Dockerfile").exists(), "Dockerfile must be generated." ), "Agent sources must be copied exactly once."
assert run_recorder.calls, "gcloud command should be executed at least once."
assert (tmp_dir / "Dockerfile").exists(), "Dockerfile must be generated."
# Manual cleanup because we disabled rmtree in the monkeypatch. # Manual cleanup because we disabled rmtree in the monkeypatch.
shutil.rmtree(tmp_dir, ignore_errors=True) shutil.rmtree(tmp_dir, ignore_errors=True)
def test_to_cloud_run_cleans_temp_dir( def test_to_cloud_run_cleans_temp_dir(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
agent_dir: Callable[[bool], Path], agent_dir: Callable[[bool], Path],
) -> None: ) -> None:
"""`to_cloud_run` should always delete the temporary folder on exit.""" """`to_cloud_run` should always delete the temporary folder on exit."""
tmp_dir = Path(tempfile.mkdtemp()) tmp_dir = Path(tempfile.mkdtemp())
src_dir = agent_dir(False) src_dir = agent_dir(False)
deleted: Dict[str, Path] = {} deleted: Dict[str, Path] = {}
def _fake_rmtree(path: str | Path, *a: Any, **k: Any) -> None: def _fake_rmtree(path: str | Path, *a: Any, **k: Any) -> None:
deleted["path"] = Path(path) deleted["path"] = Path(path)
monkeypatch.setattr(cli_deploy.shutil, "rmtree", _fake_rmtree) monkeypatch.setattr(cli_deploy.shutil, "rmtree", _fake_rmtree)
monkeypatch.setattr(subprocess, "run", _Recorder()) monkeypatch.setattr(subprocess, "run", _Recorder())
cli_deploy.to_cloud_run( cli_deploy.to_cloud_run(
agent_folder=str(src_dir), agent_folder=str(src_dir),
project="proj", project="proj",
region=None, region=None,
service_name="svc", service_name="svc",
app_name="app", app_name="app",
temp_folder=str(tmp_dir), temp_folder=str(tmp_dir),
port=8080, port=8080,
trace_to_cloud=False, trace_to_cloud=False,
with_ui=False, with_ui=False,
verbosity="info", verbosity="info",
session_db_url=None, session_db_url=None,
adk_version="0.0.5", adk_version="0.0.5",
) )
assert deleted["path"] == tmp_dir assert deleted["path"] == tmp_dir

View File

@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -137,11 +137,15 @@ def test_get_input_files_not_exists(empty_state: State):
def test_add_input_files_new(empty_state: State): def test_add_input_files_new(empty_state: State):
"""Test adding input files to an empty session state.""" """Test adding input files to an empty session state."""
ctx = CodeExecutorContext(empty_state) ctx = CodeExecutorContext(empty_state)
new_files = [File(name="new.dat", content="Yg==", mime_type="application/octet-stream")] new_files = [
ctx.add_input_files(new_files) File(name="new.dat", content="Yg==", mime_type="application/octet-stream")
assert empty_state["_code_executor_input_files"] == [
{"name": "new.dat", "content": "Yg==", "mime_type": "application/octet-stream"}
] ]
ctx.add_input_files(new_files)
assert empty_state["_code_executor_input_files"] == [{
"name": "new.dat",
"content": "Yg==",
"mime_type": "application/octet-stream",
}]
def test_add_input_files_append(context_with_data: CodeExecutorContext): def test_add_input_files_append(context_with_data: CodeExecutorContext):
@ -239,9 +243,7 @@ def test_reset_error_count_no_error_key(empty_state: State):
def test_update_code_execution_result_new_invocation(empty_state: State): def test_update_code_execution_result_new_invocation(empty_state: State):
"""Test updating code execution result for a new invocation.""" """Test updating code execution result for a new invocation."""
ctx = CodeExecutorContext(empty_state) ctx = CodeExecutorContext(empty_state)
ctx.update_code_execution_result( ctx.update_code_execution_result("inv1", "print('hi')", "hi", "")
"inv1", "print('hi')", "hi", ""
)
results = empty_state["_code_execution_results"]["inv1"] results = empty_state["_code_execution_results"]["inv1"]
assert len(results) == 1 assert len(results) == 1
assert results[0]["code"] == "print('hi')" assert results[0]["code"] == "print('hi')"
@ -272,4 +274,4 @@ def test_update_code_execution_result_append(
assert len(results) == 2 assert len(results) == 2
assert results[1]["code"] == "new_code" assert results[1]["code"] == "new_code"
assert results[1]["result_stdout"] == "new_out" assert results[1]["result_stdout"] == "new_out"
assert results[1]["result_stderr"] == "new_err" assert results[1]["result_stderr"] == "new_err"

View File

@ -15,6 +15,7 @@
"""Testings for the Trajectory Evaluator.""" """Testings for the Trajectory Evaluator."""
import math import math
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
import pytest import pytest

View File

@ -18,7 +18,8 @@ import os
import sys import sys
import time import time
import types as ptypes import types as ptypes
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock
from unittest.mock import patch
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgent
@ -31,7 +32,6 @@ from google.adk.sessions.base_session_service import ListSessionsResponse
from google.genai import types from google.genai import types
import pytest import pytest
# Configure logging to help diagnose server startup issues # Configure logging to help diagnose server startup issues
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,

View File

@ -100,6 +100,7 @@ async def test_function_system_instruction():
" test_id." " test_id."
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_function_system_instruction(): async def test_async_function_system_instruction():
async def build_function_instruction( async def build_function_instruction(

View File

@ -15,6 +15,7 @@
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from unittest.mock import Mock from unittest.mock import Mock
from google.adk.models.lite_llm import _content_to_message_param from google.adk.models.lite_llm import _content_to_message_param
from google.adk.models.lite_llm import _function_declaration_to_tool_param from google.adk.models.lite_llm import _function_declaration_to_tool_param
from google.adk.models.lite_llm import _get_content from google.adk.models.lite_llm import _get_content
@ -169,6 +170,7 @@ STREAMING_MODEL_RESPONSE = [
), ),
] ]
@pytest.fixture @pytest.fixture
def mock_response(): def mock_response():
return ModelResponse( return ModelResponse(
@ -264,57 +266,59 @@ async def test_generate_content_async(mock_acompletion, lite_llm_instance):
litellm_append_user_content_test_cases = [ litellm_append_user_content_test_cases = [
pytest.param( pytest.param(
LlmRequest( LlmRequest(
contents=[ contents=[
types.Content( types.Content(
role="developer", role="developer",
parts=[types.Part.from_text(text="Test prompt")] parts=[types.Part.from_text(text="Test prompt")],
) )
] ]
),
2,
id="litellm request without user content"
),
pytest.param(
LlmRequest(
contents=[
types.Content(
role="user",
parts=[types.Part.from_text(text="user prompt")]
)
]
),
1,
id="litellm request with user content"
),
pytest.param(
LlmRequest(
contents=[
types.Content(
role="model",
parts=[types.Part.from_text(text="model prompt")]
), ),
types.Content( 2,
role="user", id="litellm request without user content",
parts=[types.Part.from_text(text="user prompt")] ),
), pytest.param(
types.Content( LlmRequest(
role="model", contents=[
parts=[types.Part.from_text(text="model prompt")] types.Content(
) role="user",
] parts=[types.Part.from_text(text="user prompt")],
)
]
),
1,
id="litellm request with user content",
),
pytest.param(
LlmRequest(
contents=[
types.Content(
role="model",
parts=[types.Part.from_text(text="model prompt")],
),
types.Content(
role="user",
parts=[types.Part.from_text(text="user prompt")],
),
types.Content(
role="model",
parts=[types.Part.from_text(text="model prompt")],
),
]
),
4,
id="user content is not the last message scenario",
), ),
4,
id="user content is not the last message scenario"
)
] ]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"llm_request, expected_output", "llm_request, expected_output", litellm_append_user_content_test_cases
litellm_append_user_content_test_cases
) )
def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output): def test_maybe_append_user_content(
lite_llm_instance, llm_request, expected_output
):
lite_llm_instance._maybe_append_user_content(llm_request) lite_llm_instance._maybe_append_user_content(llm_request)

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import enum import enum
import pytest
from google.adk.events import Event from google.adk.events import Event
from google.adk.events import EventActions from google.adk.events import EventActions
@ -21,6 +20,7 @@ from google.adk.sessions import DatabaseSessionService
from google.adk.sessions import InMemorySessionService from google.adk.sessions import InMemorySessionService
from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.base_session_service import GetSessionConfig
from google.genai import types from google.genai import types
import pytest
class SessionServiceType(enum.Enum): class SessionServiceType(enum.Enum):

View File

@ -24,7 +24,6 @@ from google.adk.sessions import VertexAiSessionService
from google.genai import types from google.genai import types
import pytest import pytest
MOCK_SESSION_JSON_1 = { MOCK_SESSION_JSON_1 = {
'name': ( 'name': (
'projects/test-project/locations/test-location/' 'projects/test-project/locations/test-location/'

View File

@ -14,7 +14,9 @@
import base64 import base64
import json import json
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock
from unittest.mock import patch
from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient
import pytest import pytest
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
@ -464,9 +466,7 @@ class TestAPIHubClient:
MagicMock( MagicMock(
status_code=200, status_code=200,
json=lambda: { json=lambda: {
"name": ( "name": "projects/test-project/locations/us-central1/apis/api1/versions/v1",
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
),
"specs": [], "specs": [],
}, },
), # No specs ), # No specs

View File

@ -16,7 +16,8 @@ from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
from fastapi.openapi.models import Response, Schema from fastapi.openapi.models import Response
from fastapi.openapi.models import Schema
from google.adk.tools.openapi_tool.common.common import ApiParameter from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.common.common import PydocHelper from google.adk.tools.openapi_tool.common.common import PydocHelper
from google.adk.tools.openapi_tool.common.common import rename_python_keywords from google.adk.tools.openapi_tool.common.common import rename_python_keywords

View File

@ -371,9 +371,7 @@ def test_parse_external_ref_raises_error(openapi_spec_generator):
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": ( "$ref": "external_file.json#/components/schemas/ExternalSchema"
"external_file.json#/components/schemas/ExternalSchema"
)
} }
} }
}, },

View File

@ -14,9 +14,11 @@
import json import json
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock
from unittest.mock import patch
from fastapi.openapi.models import MediaType, Operation from fastapi.openapi.models import MediaType
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter as OpenAPIParameter from fastapi.openapi.models import Parameter as OpenAPIParameter
from fastapi.openapi.models import RequestBody from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Schema as OpenAPISchema from fastapi.openapi.models import Schema as OpenAPISchema
@ -25,13 +27,13 @@ from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_cred
from google.adk.tools.openapi_tool.common.common import ApiParameter from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import ( from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
RestApiTool, from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
snake_to_lower_camel, from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
to_gemini_schema,
)
from google.adk.tools.tool_context import ToolContext from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration, Schema, Type from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Type
import pytest import pytest

View File

@ -161,11 +161,9 @@ async def test_run_async_1_missing_arg_sync_func():
args = {"arg1": "test_value_1"} args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock()) result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == { assert result == {
"error": ( "error": """Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg2 arg2
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
} }
@ -176,11 +174,9 @@ async def test_run_async_1_missing_arg_async_func():
args = {"arg2": "test_value_1"} args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock()) result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == { assert result == {
"error": ( "error": """Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1 arg1
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
} }
@ -191,13 +187,11 @@ async def test_run_async_3_missing_arg_sync_func():
args = {"arg2": "test_value_1"} args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock()) result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == { assert result == {
"error": ( "error": """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1 arg1
arg3 arg3
arg4 arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
} }
@ -208,13 +202,11 @@ async def test_run_async_3_missing_arg_async_func():
args = {"arg3": "test_value_1"} args = {"arg3": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock()) result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == { assert result == {
"error": ( "error": """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1 arg1
arg2 arg2
arg4 arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
} }
@ -225,14 +217,12 @@ async def test_run_async_missing_all_arg_sync_func():
args = {} args = {}
result = await tool.run_async(args=args, tool_context=MagicMock()) result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == { assert result == {
"error": ( "error": """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1 arg1
arg2 arg2
arg3 arg3
arg4 arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
} }
@ -243,14 +233,12 @@ async def test_run_async_missing_all_arg_async_func():
args = {} args = {}
result = await tool.run_async(args=args, tool_context=MagicMock()) result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == { assert result == {
"error": ( "error": """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1 arg1
arg2 arg2
arg3 arg3
arg4 arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
} }