Merge branch 'main' into #247-OpenAPIToolSet-Considering-Required-parameters

Author: Wei Sun (Jack)
Date: 2025-05-01 18:47:28 -07:00
Committed by: GitHub

58 changed files with 1579 additions and 422 deletions
+7 -7
@@ -44,7 +44,7 @@ Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
The content to return to the user. When set, the agent run will be skipped and
the provided content will be returned to user.
"""
@@ -55,8 +55,8 @@ Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
The content to return to the user. When set, the provided content will be
appended to event history as agent response.
"""
@@ -101,8 +101,8 @@ class BaseAgent(BaseModel):
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be returned to user.
The content to return to the user. When set, the agent run will be skipped
and the provided content will be returned to user.
"""
after_agent_callback: Optional[AfterAgentCallback] = None
"""Callback signature that is invoked after the agent run.
@@ -111,8 +111,8 @@ class BaseAgent(BaseModel):
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
The content to return to the user. When set, the provided content will be
appended to event history as agent response.
"""
@final
@@ -23,7 +23,6 @@ from .readonly_context import ReadonlyContext
if TYPE_CHECKING:
from google.genai import types
from ..events.event import Event
from ..events.event_actions import EventActions
from ..sessions.state import State
from .invocation_context import InvocationContext
+3 -8
@@ -15,12 +15,7 @@
from __future__ import annotations
import logging
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
from google.genai import types
from pydantic import BaseModel
@@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Optional[dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Optional[dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
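With the widened aliases, `before_tool_callback` and `after_tool_callback` may be plain functions or coroutines; the framework awaits the result when it is awaitable (see the `inspect.isawaitable` handling in the functions flow further down). A minimal sketch, with `tool` and `tool_context` as stand-ins for the real ADK objects:

```python
from typing import Any, Optional

# Synchronous form: return a dict to short-circuit the tool call, or None
# to let the tool run normally.
def block_dangerous_tool(tool, args: dict[str, Any], tool_context) -> Optional[dict]:
    if tool.name == "delete_everything":
        return {"error": "This tool is disabled."}
    return None

# Asynchronous form with the same contract; the framework detects the
# awaitable result and awaits it before deciding whether to call the tool.
async def audit_tool_call(tool, args: dict[str, Any], tool_context) -> Optional[dict]:
    print(f"tool call: {tool.name}({args})")  # stand-in for real auditing
    return None  # do not override the tool's own response
```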
+2 -1
@@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
redirect_uri: Optional[str] = None
auth_response_uri: Optional[str] = None
auth_code: Optional[str] = None
token: Optional[Dict[str, Any]] = None
access_token: Optional[str] = None
refresh_token: Optional[str] = None
class ServiceAccountCredential(BaseModelWithConfig):
+7 -3
@@ -82,7 +82,8 @@ class AuthHandler:
or not auth_credential.oauth2
or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret
or auth_credential.oauth2.token
or auth_credential.oauth2.access_token
or auth_credential.oauth2.refresh_token
):
return self.auth_config.exchanged_auth_credential
@@ -93,7 +94,7 @@ class AuthHandler:
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
)
token = client.fetch_token(
tokens = client.fetch_token(
token_endpoint,
authorization_response=auth_credential.oauth2.auth_response_uri,
code=auth_credential.oauth2.auth_code,
@@ -102,7 +103,10 @@ class AuthHandler:
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(token=dict(token)),
oauth2=OAuth2Auth(
access_token=tokens.get("access_token"),
refresh_token=tokens.get("refresh_token"),
),
)
return updated_credential
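The net effect of the credential change: instead of stashing the whole token response under `OAuth2Auth.token`, the handler keeps explicit fields. A minimal sketch mirroring the hunk above (token values are placeholders):

```python
tokens = {"access_token": "ya29.a0...", "refresh_token": "1//0g..."}  # placeholders

updated_credential = AuthCredential(
    auth_type=AuthCredentialTypes.OAUTH2,
    oauth2=OAuth2Auth(
        access_token=tokens.get("access_token"),
        refresh_token=tokens.get("refresh_token"),
    ),
)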
+1 -1
@@ -29,5 +29,5 @@
<style>html{color-scheme:dark}html{--mat-sys-background:light-dark(#fcf9f8, #131314);--mat-sys-error:light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container:light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface:light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary:light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface:light-dark(#313030, #e5e2e2);--mat-sys-on-background:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error:light-dark(#ffffff, #690005);--mat-sys-on-error-container:light-dark(#410002, #ffdad6);--mat-sys-on-primary:light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-on-secondary:light-dark(#ffffff, #003061);--mat-sys-on-secondary-container:light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed:light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant:light-dark(#0f4784, #0f4784);--mat-sys-on-surface:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant:light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary:light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-outline:light-dark(#74777b, #8e9194);--mat-sys-outline-variant:light-dark(#c4c7ca, #44474a);--mat-sys-primary:light-dark(#595f65, #c1c7cd);--mat-sys-primary-container:light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim:light-dark(#000000, #000000);--mat-sys-secondary:light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container:light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed:light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim:light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow:light-dark(#000000, #000000);--mat-sys-surface:light-dark(#fcf9f8, #131314);--mat-sys-surface-bright:light-dark(#fcf9f8, #393939);--mat-sys-surface-container:light-dark(#f0eded, #201f20);--mat-sys-surface-container-high:light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest:light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low:light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest:light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim:light-dark(#dcd9d9, #131314);--mat-sys-surface-tint:light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant:light-dark(#e1e2e6, #44474a);--mat-sys-tertiary:light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container:light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20:#2d3134;--mat-sys-neutral10:#1c1b1c}html{--mat-sys-level0:0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1:0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2:0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3:0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4:0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5:0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 
22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large:28px;--mat-sys-corner-extra-large-top:28px 28px 0 0;--mat-sys-corner-extra-small:4px;--mat-sys-corner-extra-small-top:4px 4px 0 0;--mat-sys-corner-full:9999px;--mat-sys-corner-large:16px;--mat-sys-corner-large-end:0 16px 16px 0;--mat-sys-corner-large-start:16px 0 0 16px;--mat-sys-corner-large-top:16px 16px 0 0;--mat-sys-corner-medium:12px;--mat-sys-corner-none:0;--mat-sys-corner-small:8px}html{--mat-sys-dragged-state-layer-opacity:.16;--mat-sys-focus-state-layer-opacity:.12;--mat-sys-hover-state-layer-opacity:.08;--mat-sys-pressed-state-layer-opacity:.12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}:root{--mat-sys-primary:black;--mdc-checkbox-selected-icon-color:white;--mat-sys-background:#131314;--mat-tab-header-active-label-text-color:#8AB4F8;--mat-tab-header-active-hover-label-text-color:#8AB4F8;--mat-tab-header-active-focus-label-text-color:#8AB4F8;--mat-tab-header-label-text-weight:500;--mdc-text-button-label-text-color:#89b4f8}:root{--mdc-dialog-container-color:#2b2b2f}:root{--mdc-dialog-subhead-color:white}:root{--mdc-circular-progress-active-indicator-color:#a8c7fa}:root{--mdc-circular-progress-size:80}</style><link rel="stylesheet" href="styles-4VDSPQ37.css" media="print" onload="this.media='all'"><noscript><link rel="stylesheet" href="styles-4VDSPQ37.css"></noscript></head>
<body>
<app-root></app-root>
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-ZBO76GRM.js" type="module"></script></body>
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-HWIBUY2R.js" type="module"></script></body>
</html>
File diff suppressed because one or more lines are too long
+55 -48
@@ -39,12 +39,12 @@ class InputFile(BaseModel):
async def run_input_file(
app_name: str,
user_id: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
input_path: str,
) -> None:
) -> Session:
runner = Runner(
app_name=app_name,
agent=root_agent,
@@ -55,9 +55,11 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session.state = input_file.state
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
click.echo(f'user: {query}')
click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
@@ -65,23 +67,23 @@ async def run_input_file(
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
return session
async def run_interactively(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
) -> None:
runner = Runner(
app_name=app_name,
app_name=session.app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
while True:
query = input('user: ')
query = input('[user]: ')
if not query or not query.strip():
continue
if query == 'exit':
@@ -100,7 +102,8 @@ async def run_cli(
*,
agent_parent_dir: str,
agent_folder_name: str,
json_file_path: Optional[str] = None,
input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool,
) -> None:
"""Runs an interactive CLI for a certain agent.
@@ -109,8 +112,11 @@ async def run_cli(
agent_parent_dir: str, the absolute path of the parent folder of the agent
folder.
agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either
*.input.json or *.session.json.
input_file: Optional[str], the absolute path to the json file that contains
the initial session state and user queries, exclusive with
saved_session_file.
saved_session_file: Optional[str], the absolute path to the json file that
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit.
"""
if agent_parent_dir not in sys.path:
@@ -118,46 +124,50 @@ async def run_cli(
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path:
if json_file_path.endswith('.input.json'):
await run_input_file(
app_name=agent_folder_name,
root_agent=root_agent,
artifact_service=artifact_service,
session=session,
session_service=session_service,
input_path=json_file_path,
)
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read())
for content in session.get_contents():
if content.role == 'user':
print('user: ', content.parts[0].text)
if input_file:
session = await run_input_file(
app_name=agent_folder_name,
user_id=user_id,
root_agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
input_path=input_file,
)
elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())
if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
if event.author == 'user':
click.echo(f'[user]: {content.parts[0].text}')
else:
print(content.parts[0].text)
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
else:
print(f'Unsupported file type: {json_file_path}')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
click.echo(f'[{event.author}]: {content.parts[0].text}')
await run_interactively(
root_agent,
artifact_service,
session,
session_service,
)
else:
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
@@ -165,11 +175,8 @@ async def run_cli(
)
if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details.
session = session_service.get_session(
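For reference, `run_input_file` validates the `--replay` file against the `InputFile` model, which (per the code above) carries a `state` dict seeding the new session and a list of user `queries`. A hypothetical file:

```python
import json

# Hypothetical replay file matching the InputFile model used above:
# a `state` dict for the new session plus the user `queries` to run.
replay = {
    "state": {"customer_id": "1234"},
    "queries": [
        "What is the weather today?",
        "Book a table for two at 7pm.",
    ],
}
with open("my_agent/simple.input.json", "w") as f:
    json.dump(replay, f, indent=2)
```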
+6 -1
@@ -54,7 +54,7 @@ COPY "agents/{app_name}/" "/app/agents/{app_name}/"
EXPOSE {port}
CMD adk {command} --port={port} {trace_to_cloud_option} "/app/agents"
CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
"""
@@ -85,6 +85,7 @@ def to_cloud_run(
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Google Cloud Run.
@@ -112,6 +113,7 @@ def to_cloud_run(
trace_to_cloud: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
session_db_url: The database URL to connect the session.
"""
app_name = app_name or os.path.basename(agent_folder)
@@ -144,6 +146,9 @@ def to_cloud_run(
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
session_db_option=f'--session_db_url={session_db_url}'
if session_db_url
else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
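The option string is rendered into the templated CMD only when a URL is supplied; a sketch of the substitution with hypothetical values:

```python
# Hypothetical values showing how the templated CMD renders; an empty
# session_db_url collapses the flag to an empty string.
session_db_url = "sqlite:///./sessions.db"
session_db_option = f"--session_db_url={session_db_url}" if session_db_url else ""
print(f'CMD adk web --port=8000 {session_db_option} --trace_to_cloud "/app/agents"')
# CMD adk web --port=8000 --session_db_url=sqlite:///./sessions.db --trace_to_cloud "/app/agents"
```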
+1 -1
@@ -256,7 +256,7 @@ def run_evals(
)
if final_eval_status == EvalStatus.PASSED:
result = "✅ Passsed"
result = "✅ Passed"
else:
result = "❌ Failed"
+78 -13
@@ -96,6 +96,23 @@ def cli_create_cmd(
)
def validate_exclusive(ctx, param, value):
# Store the validated parameters in the context
if not hasattr(ctx, "exclusive_opts"):
ctx.exclusive_opts = {}
# If this option has a value and we've already seen another exclusive option
if value is not None and any(ctx.exclusive_opts.values()):
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
raise click.UsageError(
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
)
# Record this option's value
ctx.exclusive_opts[param.name] = value is not None
return value
@main.command("run")
@click.option(
"--save_session",
@@ -105,13 +122,43 @@ def cli_create_cmd(
default=False,
help="Optional. Whether to save the session to a json file on exit.",
)
@click.option(
"--replay",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains the initial state of the session and user"
" queries. A new session will be created using this state. And user"
" queries are run againt the newly created session. Users cannot"
" continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.option(
"--resume",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains a previously saved session (by"
"--save_session option). The previous session will be re-displayed. And"
" user can continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(agent: str, save_session: bool):
def cli_run(
agent: str,
save_session: bool,
replay: Optional[str],
resume: Optional[str],
):
"""Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
@@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
run_cli(
agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name,
input_file=replay,
saved_session_file=resume,
save_session=save_session,
)
)
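Since both new options are declared with `callback=validate_exclusive`, supplying both should fail fast. A quick sketch with click's test runner (it assumes the referenced files exist, because both options use `click.Path(exists=True)`; `main` is the click group this command hangs off):

```python
from click.testing import CliRunner

runner = CliRunner()
result = runner.invoke(
    main, ["run", "--replay", "replay.json", "--resume", "saved.json", "my_agent"]
)
# validate_exclusive raises a UsageError along the lines of:
#   Options 'resume' and 'replay' cannot be set together.
print(result.output)
```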
@@ -245,12 +294,13 @@ def cli_eval(
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
@@ -366,12 +416,13 @@ def cli_web(
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
@@ -541,6 +592,18 @@ def cli_api_server(
default="WARNING",
help="Optional. Override the default verbosity level.",
)
@click.option(
"--session_db_url",
help=(
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.argument(
"agent",
type=click.Path(
@@ -558,6 +621,7 @@ def cli_deploy_cloud_run(
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Cloud Run.
@@ -579,6 +643,7 @@ def cli_deploy_cloud_run(
trace_to_cloud=trace_to_cloud,
with_ui=with_ui,
verbosity=verbosity,
session_db_url=session_db_url,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)
+6
@@ -756,6 +756,12 @@ def get_fast_api_app(
except Exception as e:
logger.exception("Error during live websocket communication: %s", e)
traceback.print_exc()
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
await websocket.close(
code=WEBSOCKET_INTERNAL_ERROR_CODE,
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
)
finally:
for task in pending:
task.cancel()
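The two constants follow RFC 6455: 1011 is the "internal error" close code, and a close frame's reason must fit in 123 bytes because control-frame payloads are capped at 125 bytes, two of which hold the status code. One caveat worth noting in a sketch:

```python
WEBSOCKET_MAX_BYTES_FOR_REASON = 125 - 2  # control-frame cap minus 2-byte close code

def safe_close_reason(exc: Exception) -> str:
    # Slicing a str truncates by characters, not bytes; a reason full of
    # multi-byte UTF-8 could still exceed 123 bytes once encoded.
    return str(exc)[:WEBSOCKET_MAX_BYTES_FOR_REASON]
```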
+2 -2
@@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:
class AgentEvaluator:
"""An evaluator for Agents, mainly intented for helping with test cases."""
"""An evaluator for Agents, mainly intended for helping with test cases."""
@staticmethod
def find_config_for_test_file(test_file: str):
@@ -91,7 +91,7 @@ class AgentEvaluator:
look for 'root_agent' in the loaded module.
eval_dataset: The eval data set. This can be either a string representing
full path to the file containing eval dataset, or a directory that is
recusively explored for all files that have a `.test.json` suffix.
recursively explored for all files that have a `.test.json` suffix.
num_runs: Number of times all entries in the eval dataset should be
assessed.
agent_name: The name of the agent.
@@ -35,7 +35,7 @@ class ResponseEvaluator:
Args:
raw_eval_dataset: The dataset that will be evaluated.
evaluation_criteria: The evaluation criteria to be used. This method
support two criterias, `response_evaluation_score` and
support two criteria, `response_evaluation_score` and
`response_match_score`.
print_detailed_results: Prints detailed results on the console. This is
usually helpful during debugging.
@@ -56,7 +56,7 @@ class ResponseEvaluator:
Value range: [0, 5], where 0 means that the agent's response is not
coherent, while 5 means it is. High values are good.
A note on raw_eval_dataset:
The dataset should be a list session, where each sesssion is represented
The dataset should be a list session, where each session is represented
as a list of interaction that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:
@@ -31,10 +31,9 @@ class TrajectoryEvaluator:
):
r"""Returns the mean tool use accuracy of the eval dataset.
Tool use accuracy is calculated by comparing the expected and actuall tool
use trajectories. An exact match scores a 1, 0 otherwise. The final number
is an
average of these individual scores.
Tool use accuracy is calculated by comparing the expected and the actual
tool use trajectories. An exact match scores a 1, 0 otherwise. The final
number is an average of these individual scores.
Value range: [0, 1], where 0 means none of the tool use entries aligned,
and 1 would mean all of them aligned. Higher value is good.
@@ -45,7 +44,7 @@ class TrajectoryEvaluator:
usually helpful during debugging.
A note on eval_dataset:
The dataset should be a list session, where each sesssion is represented
The dataset should be a list session, where each session is represented
as a list of interaction that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:
+9 -4
@@ -48,8 +48,13 @@ class EventActions(BaseModel):
"""The agent is escalating to a higher level agent."""
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
"""Will only be set by a tool response indicating tool request euc.
dict key is the function call id since one function call response (from model)
could correspond to multiple function calls.
dict value is the required auth config.
"""Authentication configurations requested by tool responses.
This field will only be set by a tool response event indicating tool request
auth credential.
- Keys: The function call id. One function response event could contain
multiple function responses that correspond to multiple function calls, and
each function call could request different auth configs, so this id is used
to identify the function call.
- Values: The requested auth config.
"""
@@ -94,7 +94,7 @@ can answer it.
If another agent is better for answering the question according to its
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
question to that agent. When transfering, do not generate any text other than
question to that agent. When transferring, do not generate any text other than
the function call.
"""
@@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
yield event
# send back the function response
if event.get_function_responses():
logger.debug('Sending back last function resonse event: %s', event)
logger.debug('Sending back last function response event: %s', event)
invocation_context.live_request_queue.send_content(event.content)
if (
event.content
+1 -1
@@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
"""Rearrange the events for the latest function_response.
If the latest function_response is for an async function_call, all events
bewteen the initial function_call and the latest function_response will be
between the initial function_call and the latest function_response will be
removed.
Args:
+38 -18
@@ -151,28 +151,33 @@ async def handle_function_calls_async(
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
function_response: Optional[dict] = None
# before_tool_callback (sync or async)
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
# Calls after_tool_callback if it exists.
# after_tool_callback (sync or async)
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow long running function to return None to not provide function response.
@@ -223,11 +228,17 @@ async def handle_function_calls_live(
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
# # Calls the tool if before_tool_callback does not exist or returns None.
# if agent.before_tool_callback:
# function_response = agent.before_tool_callback(
# tool, function_args, tool_context
# )
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool, function_args, tool_context
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response:
function_response = await _process_function_live_helper(
@@ -235,15 +246,26 @@ async def handle_function_calls_live(
)
# Calls after_tool_callback if it exists.
# if agent.after_tool_callback:
# new_response = agent.after_tool_callback(
# tool,
# function_args,
# tool_context,
# function_response,
# )
# if new_response:
# function_response = new_response
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
tool,
function_args,
tool_context,
function_response,
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.
@@ -310,9 +332,7 @@ async def _process_function_live_helper(
function_response = {
'status': f'No active streaming function named {function_name} found'
}
elif inspect.isasyncgenfunction(tool.func):
print('is async')
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
# for streaming tool use case
# we require the function to be a async generator function
async def run_tool_and_update_queue(tool, function_args, tool_context):
@@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
# Appends global instructions if set.
if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not emtpy str
): # not empty str
raw_si = root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)
@@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
llm_request.append_instructions([si])
# Appends agent instructions if set.
if agent.instruction: # not emtpy str
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
si = _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])
@@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
):
# TODO: Right now, we just support output_transcription without
# changing interface and data protocol. Later, we can consider to
# support output_transcription as a separete field in LlmResponse.
# support output_transcription as a separate field in LlmResponse.
# Transcription is always considered as partial event
# We rely on other control signals to determine when to yield the
@@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
# in case of empty content or parts, we still surface it
# in case it's an interrupted message, we merge the previous partial
# text. Otherwise we don't merge, because content can be None when the model
# safty threshold is triggered
# safety threshold is triggered
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''
+10 -1
@@ -14,7 +14,7 @@
from __future__ import annotations
from typing import Optional
from typing import Any, Optional
from google.genai import types
from pydantic import BaseModel
@@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
error_message: Error message if the response is an error.
interrupted: Flag indicating that LLM was interrupted when generating the
content. Usually it's due to user interruption during a bidi streaming.
custom_metadata: The custom metadata of the LlmResponse.
"""
model_config = ConfigDict(extra='forbid')
@@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
Usually it's due to user interruption during a bidi streaming.
"""
custom_metadata: Optional[dict[str, Any]] = None
"""The custom metadata of the LlmResponse.
An optional key-value pair to label an LlmResponse.
NOTE: the entire dict must be JSON serializable.
"""
@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,
@@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
llm_request: The LLM request to apply the thinking config to.
"""
if self.thinking_config:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.thinking_config = self.thinking_config
@override
+29
@@ -0,0 +1,29 @@
"""Utility functions for session service."""
import base64
from typing import Any, Optional
from google.genai import types
def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content
def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content)
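A round-trip sketch of the two helpers: raw bytes in `inline_data` are not JSON-serializable, so `encode_content` base64-encodes them into a plain dict and `decode_content` restores the `types.Content`:

```python
from google.genai import types

content = types.Content(
    role="user",
    parts=[types.Part.from_bytes(data=b"\x89PNG...", mime_type="image/png")],
)

encoded = encode_content(content)   # JSON-safe dict; bytes became a base64 str
assert isinstance(encoded["parts"][0]["inline_data"]["data"], str)

restored = decode_content(encoded)  # back to types.Content with real bytes
assert restored.parts[0].inline_data.data == b"\x89PNG..."
```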
@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from datetime import datetime
import json
import logging
from typing import Any
from typing import Optional
from typing import Any, Optional
import uuid
from sqlalchemy import Boolean
@@ -27,6 +24,7 @@ from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -48,6 +46,7 @@ from typing_extensions import override
from tzlocal import get_localzone
from ..events.event import Event
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
@@ -58,6 +57,9 @@ from .state import State
logger = logging.getLogger(__name__)
DEFAULT_MAX_KEY_LENGTH = 128
DEFAULT_MAX_VARCHAR_LENGTH = 256
class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
@@ -70,15 +72,16 @@ class DynamicJSON(TypeDecorator):
def load_dialect_impl(self, dialect: Dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
return dialect.type_descriptor(Text) # Default to Text for other dialects
if dialect.name == "mysql":
# Use LONGTEXT for MySQL to address the data too long issue
return dialect.type_descriptor(mysql.LONGTEXT)
return dialect.type_descriptor(Text) # Default to Text for other dialects
def process_bind_param(self, value, dialect: Dialect):
if value is not None:
if dialect.name == "postgresql":
return value # JSONB handles dict directly
else:
return json.dumps(value) # Serialize to JSON string for TEXT
return json.dumps(value) # Serialize to JSON string for TEXT
return value
def process_result_value(self, value, dialect: Dialect):
@@ -92,17 +95,25 @@ class DynamicJSON(TypeDecorator):
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageSession(Base):
"""Represents a session stored in the database."""
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
String(DEFAULT_MAX_KEY_LENGTH),
primary_key=True,
default=lambda: str(uuid.uuid4()),
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
@@ -125,18 +136,29 @@ class StorageSession(Base):
class StorageEvent(Base):
"""Represents an event stored in the database."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
branch: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
@@ -147,8 +169,10 @@ class StorageEvent(Base):
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
error_code: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
@@ -182,9 +206,12 @@ class StorageEvent(Base):
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@@ -192,13 +219,17 @@ class StorageAppState(Base):
DateTime(), default=func.now(), onupdate=func.now()
)
class StorageUserState(Base):
"""Represents a user state stored in the database."""
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@@ -217,7 +248,7 @@ class DatabaseSessionService(BaseSessionService):
"""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properies
# 3. Initialize all properties
try:
db_engine = create_engine(db_url)
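Anything SQLAlchemy's `create_engine` accepts should work here. A sketch, assuming the service is imported from `google.adk.sessions` and that the constructor parameter is named `db_url` as in the snippet above (driver packages such as `psycopg2` or `PyMySQL` must be installed separately):

```python
from google.adk.sessions import DatabaseSessionService

DatabaseSessionService(db_url="sqlite:///./sessions.db")
DatabaseSessionService(db_url="postgresql://user:pass@localhost/adk")     # JSONB path
DatabaseSessionService(db_url="mysql+pymysql://user:pass@localhost/adk")  # LONGTEXT path
```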
@@ -353,6 +384,7 @@ class DatabaseSessionService(BaseSessionService):
else True
)
.limit(config.num_recent_events if config else None)
.order_by(StorageEvent.timestamp.asc())
.all()
)
@@ -383,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_decode_content(e.content),
content=_session_util.decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
@@ -506,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
storage_event.content = _session_util.encode_content(event.content)
sessionFactory.add(storage_event)
@@ -574,10 +598,3 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return content
+1 -1
@@ -26,7 +26,7 @@ class State:
"""
Args:
value: The current value of the state dict.
delta: The delta change to the current value that hasn't been commited.
delta: The delta change to the current value that hasn't been committed.
"""
self._value = value
self._delta = delta
@@ -14,21 +14,23 @@
import logging
import re
import time
from typing import Any
from typing import Optional
from typing import Any, Optional
from dateutil.parser import isoparse
from dateutil import parser
from google import genai
from typing_extensions import override
from ..events.event import Event
from ..events.event_actions import EventActions
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session
isoparse = parser.isoparse
logger = logging.getLogger(__name__)
@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = event.content.model_dump(exclude_none=True)
event_json['content'] = _session_util.encode_content(event.content)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:
@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
invocation_id=api_event['invocationId'],
author=api_event['author'],
actions=event_actions,
content=api_event.get('content', None),
content=_session_util.decode_content(api_event.get('content', None)),
timestamp=isoparse(api_event['timestamp']).timestamp(),
error_code=api_event.get('errorCode', None),
error_message=api_event.get('errorMessage', None),
+2 -3
@@ -45,10 +45,9 @@ class AgentTool(BaseTool):
skip_summarization: Whether to skip summarization of the agent output.
"""
def __init__(self, agent: BaseAgent):
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
self.agent = agent
self.skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""
self.skip_summarization: bool = skip_summarization
super().__init__(name=agent.name, description=agent.description)
@@ -13,7 +13,9 @@
# limitations under the License.
from .application_integration_toolset import ApplicationIntegrationToolset
from .integration_connector_tool import IntegrationConnectorTool
__all__ = [
'ApplicationIntegrationToolset',
'IntegrationConnectorTool',
]
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Optional
from typing import Dict, List, Optional
from fastapi.openapi.models import HTTPBearer
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.connections_client import ConnectionsClient
from .clients.integration_client import IntegrationClient
from .integration_connector_tool import IntegrationConnectorTool
# TODO(cheliu): Apply a common toolset interface
@@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
actions,
service_account_json,
)
connection_details = {}
if integration and trigger:
spec = integration_client.get_openapi_spec_for_integration()
elif connection and (entity_operations or actions):
@@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
project, location, connection, service_account_json
)
connection_details = connections_client.get_connection_details()
tool_instructions += (
"ALWAYS use serviceName = "
+ connection_details["serviceName"]
+ ", host = "
+ connection_details["host"]
+ " and the connection name = "
+ f"projects/{project}/locations/{location}/connections/{connection} when"
" using this tool"
+ ". DONOT ask the user for these values as you already have those."
)
spec = integration_client.get_openapi_spec_for_connection(
tool_name,
tool_instructions,
@@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
"Either (integration and trigger) or (connection and"
" (entity_operations or actions)) should be provided."
)
self._parse_spec_to_tools(spec)
self._parse_spec_to_tools(spec, connection_details)
def _parse_spec_to_tools(self, spec_dict):
def _parse_spec_to_tools(self, spec_dict, connection_details):
"""Parses the spec dict to a list of RestApiTool."""
if self.service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json(
@@ -218,12 +209,43 @@ class ApplicationIntegrationToolset:
),
)
auth_scheme = HTTPBearer(bearerFormat="JWT")
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
auth_scheme=auth_scheme,
).get_tools()
for tool in tools:
if self.integration and self.trigger:
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
auth_scheme=auth_scheme,
).get_tools()
for tool in tools:
self.generated_tools[tool.name] = tool
return
operations = OpenApiSpecParser().parse(spec_dict)
for open_api_operation in operations:
operation = getattr(open_api_operation.operation, "x-operation")
entity = None
action = None
if hasattr(open_api_operation.operation, "x-entity"):
entity = getattr(open_api_operation.operation, "x-entity")
elif hasattr(open_api_operation.operation, "x-action"):
action = getattr(open_api_operation.operation, "x-action")
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
if auth_scheme:
rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential)
tool = IntegrationConnectorTool(
name=rest_api_tool.name,
description=rest_api_tool.description,
connection_name=connection_details["name"],
connection_host=connection_details["host"],
connection_service_name=connection_details["serviceName"],
entity=entity,
action=action,
operation=operation,
rest_api_tool=rest_api_tool,
)
self.generated_tools[tool.name] = tool
def get_tools(self) -> List[RestApiTool]:
@@ -68,12 +68,14 @@ class ConnectionsClient:
response = self._execute_api_call(url)
connection_data = response.json()
connection_name = connection_data.get("name", "")
service_name = connection_data.get("serviceDirectory", "")
host = connection_data.get("host", "")
if host:
service_name = connection_data.get("tlsServiceDirectory", "")
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
return {
"name": connection_name,
"serviceName": service_name,
"host": host,
"authOverrideEnabled": auth_override_enabled,
@@ -291,13 +293,9 @@ class ConnectionsClient:
tool_name: str = "",
tool_instructions: str = "",
) -> Dict[str, Any]:
description = (
f"Use this tool with" f' action = "{action}" and'
) + f' operation = "{operation}" only. Dont ask these values from user.'
description = f"Use this tool to execute {action}"
if operation == "EXECUTE_QUERY":
description = (
(f"Use this tool with" f' action = "{action}" and')
+ f' operation = "{operation}" only. Dont ask these values from user.'
description += (
" Use pageSize = 50 and timeout = 120 until user specifies a"
" different value otherwise. If user provides a query in natural"
" language, convert it to SQL query and then execute it using the"
@@ -308,6 +306,8 @@ class ConnectionsClient:
"summary": f"{action_display_name}",
"description": f"{description} {tool_instructions}",
"operationId": f"{tool_name}_{action_display_name}",
"x-action": f"{action}",
"x-operation": f"{operation}",
"requestBody": {
"content": {
"application/json": {
@@ -347,16 +347,12 @@ class ConnectionsClient:
"post": {
"summary": f"List {entity}",
"description": (
f"Returns all entities of type {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ ' operation = "LIST_ENTITIES" only. Dont ask these values'
" from"
+ ' user. Always use ""'
+ ' as filter clause and ""'
+ " as page token and 50 as page size until user specifies a"
" different value otherwise. Use single quotes for strings in"
f" filter clause. {tool_instructions}"
f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
following format: `field_name1='value1' AND field_name2='value2'
`. {tool_instructions}"""
),
"x-operation": "LIST_ENTITIES",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_list_{entity}",
"requestBody": {
"content": {
@@ -401,14 +397,11 @@ class ConnectionsClient:
"post": {
"summary": f"Get {entity}",
"description": (
(
f"Returns the details of the {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
f" user. {tool_instructions}"
f"Returns the details of the {entity}. {tool_instructions}"
),
"operationId": f"{tool_name}_get_{entity}",
"x-operation": "GET_ENTITY",
"x-entity": f"{entity}",
"requestBody": {
"content": {
"application/json": {
@@ -445,17 +438,10 @@ class ConnectionsClient:
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Create {entity}",
"description": (
(
f"Creates a new entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Follow the schema of the entity provided in the"
f" instructions to create {entity}. {tool_instructions}"
),
"summary": f"Creates a new {entity}",
"description": f"Creates a new {entity}. {tool_instructions}",
"x-operation": "CREATE_ENTITY",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_create_{entity}",
"requestBody": {
"content": {
@@ -491,18 +477,10 @@ class ConnectionsClient:
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Update {entity}",
"description": (
(
f"Updates an entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Use entityId to uniquely identify the entity to"
" update. Follow the schema of the entity provided in the"
f" instructions to update {entity}. {tool_instructions}"
),
"summary": f"Updates the {entity}",
"description": f"Updates the {entity}. {tool_instructions}",
"x-operation": "UPDATE_ENTITY",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_update_{entity}",
"requestBody": {
"content": {
@@ -538,16 +516,10 @@ class ConnectionsClient:
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Delete {entity}",
"description": (
(
f"Deletes an entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
" from"
f" user. {tool_instructions}"
),
"summary": f"Delete the {entity}",
"description": f"Deletes the {entity}. {tool_instructions}",
"x-operation": "DELETE_ENTITY",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_delete_{entity}",
"requestBody": {
"content": {
@@ -0,0 +1,159 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .. import BaseTool
from ..tool_context import ToolContext
logger = logging.getLogger(__name__)
class IntegrationConnectorTool(BaseTool):
"""A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.
This tool adds Application Integration specific context like connection
details, entity, operation, and action to the underlying REST API call
handled by RestApiTool. It prepares the arguments and then delegates the
actual API call execution to the contained RestApiTool instance.
* Generates request params and body
* Attaches auth credentials to API call.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
"""
EXCLUDE_FIELDS = [
'connection_name',
'service_name',
'host',
'entity',
'operation',
'action',
]
OPTIONAL_FIELDS = [
'page_size',
'page_token',
'filter',
]
def __init__(
self,
name: str,
description: str,
connection_name: str,
connection_host: str,
connection_service_name: str,
entity: str,
operation: str,
action: str,
rest_api_tool: RestApiTool,
):
"""Initializes the ApplicationIntegrationTool.
Args:
name: The name of the tool, typically derived from the API operation.
Should be unique and adhere to Gemini function naming conventions
(e.g., less than 64 characters).
description: A description of what the tool does, usually based on the
API operation's summary or description.
connection_name: The name of the Integration Connector connection.
connection_host: The hostname or IP address for the connection.
connection_service_name: The specific service name within the host.
entity: The Integration Connector entity being targeted.
operation: The specific operation being performed on the entity.
action: The action associated with the operation (e.g., 'execute').
rest_api_tool: An initialized RestApiTool instance that handles the
underlying REST API communication based on an OpenAPI specification
operation. This tool will be called by ApplicationIntegrationTool with
added connection and context arguments. tool =
[RestApiTool.from_parsed_operation(o) for o in operations]
"""
# Gemini restricts function names to fewer than 64 characters
super().__init__(
name=name,
description=description,
)
self.connection_name = connection_name
self.connection_host = connection_host
self.connection_service_name = connection_service_name
self.entity = entity
self.operation = operation
self.action = action
self.rest_api_tool = rest_api_tool
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS:
if field in schema_dict['properties']:
del schema_dict['properties'][field]
for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
if field in schema_dict['required']:
schema_dict['required'].remove(field)
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
args['connection_name'] = self.connection_name
args['service_name'] = self.connection_service_name
args['host'] = self.connection_host
args['entity'] = self.entity
args['operation'] = self.operation
args['action'] = self.action
logger.info('Running tool: %s with args: %s', self.name, args)
return self.rest_api_tool.call(args=args, tool_context=tool_context)
def __str__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}", entity="{self.entity}",'
f' operation="{self.operation}", action="{self.action}")'
)
def __repr__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}",'
f' connection_host="{self.connection_host}",'
f' connection_service_name="{self.connection_service_name}",'
f' entity="{self.entity}", operation="{self.operation}",'
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
)
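# Illustrative sketch: a minimal, self-contained model of the delegation
# pattern above. `FakeRestApiTool` and every value below are hypothetical
# stand-ins, not the real RestApiTool API; the point is only that connection
# context is merged into the model-supplied args before the wrapped tool runs.
from typing import Any, Dict, Optional


class FakeRestApiTool:

  def call(
      self, *, args: Dict[str, Any], tool_context: Optional[Any] = None
  ) -> Dict[str, Any]:
    # Echo the args so the injected connection context is visible.
    return {'received_args': args}


def run_with_connection_context(
    rest_api_tool: FakeRestApiTool, args: Dict[str, Any]
) -> Dict[str, Any]:
  args = dict(args)  # copy, so the caller's dict is not mutated
  args['connection_name'] = 'projects/p/locations/l/connections/c'  # hypothetical
  args['service_name'] = 'issue-service'  # hypothetical
  args['host'] = 'integrations.example.com'  # hypothetical
  args['entity'] = 'Issue'
  args['operation'] = 'LIST_ENTITIES'
  args['action'] = ''
  return rest_api_tool.call(args=args, tool_context=None)


print(run_with_connection_context(FakeRestApiTool(), {'filter': 'status=open'}))
# End of sketch.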
+42
View File
@@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context
# Before invoking the function, check whether all mandatory arguments are
# present in the args passed in.
# If the check fails, don't invoke the tool; instead, let the Agent know
# that an input parameter was missing. This helps the underlying model
# fix the issue and retry.
mandatory_args = self._get_mandatory_args()
missing_mandatory_args = [
arg for arg in mandatory_args if arg not in args_to_call
]
if missing_mandatory_args:
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
{missing_mandatory_args_str}
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str}
if inspect.iscoroutinefunction(self.func):
return await self.func(**args_to_call) or {}
else:
@@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
args_to_call['tool_context'] = tool_context
async for item in self.func(**args_to_call):
yield item
def _get_mandatory_args(
self,
) -> list[str]:
"""Identifies mandatory parameters (those without default values) for a function.
Returns:
A list of strings, where each string is the name of a mandatory parameter.
"""
signature = inspect.signature(self.func)
mandatory_params = []
for name, param in signature.parameters.items():
# A parameter is mandatory if:
# 1. It has no default value (param.default is inspect.Parameter.empty)
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
#
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
if param.default is inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
mandatory_params.append(name)
return mandatory_params
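# Illustrative sketch: standalone demo of the mandatory-argument detection
# above. Parameters with no default value, excluding *args/**kwargs, count as
# mandatory. `book_flight` is a hypothetical example function.
import inspect


def get_mandatory_args(func) -> list[str]:
  signature = inspect.signature(func)
  return [
      name
      for name, param in signature.parameters.items()
      if param.default is inspect.Parameter.empty
      and param.kind
      not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
  ]


def book_flight(origin: str, destination: str, seats: int = 1, **extras):
  ...


print(get_mandatory_args(book_flight))  # ['origin', 'destination']
# End of sketch.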
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import os
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Optional
@@ -28,6 +30,7 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
class GoogleApiToolSet:
"""Google API Tool Set."""
def __init__(self, tools: List[RestApiTool]):
self.tools: Final[List[GoogleApiTool]] = [
@@ -45,10 +48,10 @@ class GoogleApiToolSet:
@staticmethod
def _load_tool_set_with_oidc_auth(
spec_file: str = None,
spec_dict: Dict[str, Any] = None,
scopes: list[str] = None,
) -> Optional[OpenAPIToolset]:
spec_file: Optional[str] = None,
spec_dict: Optional[dict[str, Any]] = None,
scopes: Optional[list[str]] = None,
) -> OpenAPIToolset:
spec_str = None
if spec_file:
# Get the frame of the caller
@@ -90,18 +93,18 @@ class GoogleApiToolSet:
@classmethod
def load_tool_set(
cl: Type['GoogleApiToolSet'],
cls: Type[GoogleApiToolSet],
api_name: str,
api_version: str,
) -> 'GoogleApiToolSet':
) -> GoogleApiToolSet:
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
scope = list(
spec_dict['components']['securitySchemes']['oauth2']['flows'][
'authorizationCode'
]['scopes'].keys()
)[0]
return cl(
cl._load_tool_set_with_oidc_auth(
return cls(
cls._load_tool_set_with_oidc_auth(
spec_dict=spec_dict, scopes=[scope]
).get_tools()
)
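# Illustrative sketch: how the scope is selected in load_tool_set above. The
# first scope listed under the spec's OAuth2 authorization-code flow is used.
# The spec below is a hypothetical, trimmed-down stand-in for the converter's
# output.
spec_dict = {
    'components': {
        'securitySchemes': {
            'oauth2': {
                'flows': {
                    'authorizationCode': {
                        'scopes': {
                            'https://www.googleapis.com/auth/calendar': 'rw',
                        }
                    }
                }
            }
        }
    }
}
scope = list(
    spec_dict['components']['securitySchemes']['oauth2']['flows'][
        'authorizationCode'
    ]['scopes'].keys()
)[0]
print(scope)  # https://www.googleapis.com/auth/calendar
# End of sketch.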
+1 -1
View File
@@ -89,7 +89,7 @@ class LoadArtifactsTool(BaseTool):
than the function call.
"""])
# Attache the content of the artifacts if the model requests them.
# Attach the content of the artifacts if the model requests them.
# This only adds the content to the model request, instead of the session.
if llm_request.contents and llm_request.contents[-1].parts:
function_response = llm_request.contents[-1].parts[0].function_response
@@ -66,10 +66,10 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
Returns:
An AuthCredential object containing the HTTP bearer access token. If the
HTTO bearer token cannot be generated, return the origianl credential
HTTP bearer token cannot be generated, return the original credential.
"""
if "access_token" not in auth_credential.oauth2.token:
if not auth_credential.oauth2.access_token:
return auth_credential
# Return the access token as a bearer token.
@@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(
token=auth_credential.oauth2.token["access_token"]
token=auth_credential.oauth2.access_token
),
),
)
@@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
return auth_credential
# If an access token is available, exchange it for an HTTP bearer token.
if auth_credential.oauth2.token:
if auth_credential.oauth2.access_token:
return self.generate_auth_token(auth_credential)
return None
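# Illustrative sketch: simplified model of the exchange flow above. Once
# oauth2.access_token is populated, it is repackaged as an HTTP bearer
# credential. The dataclasses are hypothetical stand-ins for the real
# AuthCredential models.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeOAuth2:
  access_token: Optional[str] = None
  refresh_token: Optional[str] = None


@dataclass
class FakeBearer:
  scheme: str
  token: str


def exchange(oauth2: FakeOAuth2) -> Optional[FakeBearer]:
  if not oauth2.access_token:
    return None  # no token yet, nothing to exchange
  return FakeBearer(scheme="bearer", token=oauth2.access_token)


print(exchange(FakeOAuth2(access_token="ya29.example")))
# End of sketch.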
@@ -124,7 +124,7 @@ class OpenAPIToolset:
def _load_spec(
self, spec_str: str, spec_type: Literal["json", "yaml"]
) -> Dict[str, Any]:
"""Loads the OpenAPI spec string into adictionary."""
"""Loads the OpenAPI spec string into a dictionary."""
if spec_type == "json":
return json.loads(spec_str)
elif spec_type == "yaml":
@@ -14,20 +14,12 @@
import inspect
from textwrap import dedent
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from typing import Any, Dict, List, Optional, Union
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from fastapi.openapi.models import Operation, Parameter, Schema
from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
from ..common.common import ApiParameter, PydocHelper, to_snake_case
class OperationParser:
@@ -113,7 +105,8 @@ class OperationParser:
description = request_body.description or ''
if schema and schema.type == 'object':
for prop_name, prop_details in schema.properties.items():
properties = schema.properties or {}
for prop_name, prop_details in properties.items():
self.params.append(
ApiParameter(
original_name=prop_name,
@@ -17,6 +17,7 @@ from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
@@ -59,6 +60,40 @@ def snake_to_lower_camel(snake_case_string: str):
])
# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def normalize_json_schema_type(
json_schema_type: Optional[Union[str, Sequence[str]]],
) -> tuple[Optional[str], bool]:
"""Converts a JSON Schema Type into Gemini Schema type.
Adopted and modified from Gemini SDK. This gets the first available schema
type from JSON Schema, and use it to mark Gemini schema type. If JSON Schema
contains a list of types, the first non null type is used.
Remove this after switching to Gemini `from_json_schema`.
"""
if json_schema_type is None:
return None, False
if isinstance(json_schema_type, str):
if json_schema_type == "null":
return None, True
return json_schema_type, False
non_null_types = []
nullable = False
# If the JSON Schema type is an array, pick the first non-null type.
for type_value in json_schema_type:
if type_value == "null":
nullable = True
else:
non_null_types.append(type_value)
non_null_type = non_null_types[0] if non_null_types else None
return non_null_type, nullable
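# Illustrative sketch: expected behavior of normalize_json_schema_type as
# defined above.
assert normalize_json_schema_type(None) == (None, False)
assert normalize_json_schema_type("string") == ("string", False)
assert normalize_json_schema_type("null") == (None, True)
assert normalize_json_schema_type(["null", "integer"]) == ("integer", True)
# End of sketch.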
# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
@@ -82,13 +117,6 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
if not openapi_schema.get("type"):
openapi_schema["type"] = "object"
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
# See b/385165182
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
"properties"
):
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
for key, value in openapi_schema.items():
snake_case_key = to_snake_case(key)
# Check if the snake_case_key exists in the Schema model's fields.
@@ -99,7 +127,17 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
# Format: properties[expiration].format: only 'enum' and 'date-time' are
# supported for STRING type
continue
if snake_case_key == "properties" and isinstance(value, dict):
elif snake_case_key == "type":
schema_type, nullable = normalize_json_schema_type(
openapi_schema.get("type", None)
)
# Force a type onto an otherwise untyped schema. This avoids the
# "... one_of or any_of must specify a type" error.
pydantic_schema_data["type"] = (schema_type or "object").upper()
if nullable:
pydantic_schema_data["nullable"] = True
elif snake_case_key == "properties" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = {
k: to_gemini_schema(v) for k, v in value.items()
}
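# Illustrative sketch: the net effect of the new "type" branch above, mirrored
# by a small hypothetical helper (reuses normalize_json_schema_type from
# above).
def gemini_type_fields(json_type):
  schema_type, nullable = normalize_json_schema_type(json_type)
  fields = {"type": (schema_type or "object").upper()}
  if nullable:
    fields["nullable"] = True
  return fields


assert gemini_type_fields(["string", "null"]) == {
    "type": "STRING",
    "nullable": True,
}
assert gemini_type_fields(None) == {"type": "OBJECT"}
# End of sketch.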
+1 -1
View File
@@ -13,4 +13,4 @@
# limitations under the License.
# version: date+base_cl
__version__ = "0.1.1"
__version__ = "0.3.0"