mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-04 13:56:24 -06:00
Merge branch 'main' into #247-OpenAPIToolSet-Considering-Required-parameters
@@ -44,7 +44,7 @@ Args:
callback_context: MUST be named 'callback_context' (enforced).

Returns:
The content to return to the user. When set, the agent run will skipped and
The content to return to the user. When set, the agent run will be skipped and
the provided content will be returned to user.
"""

@@ -55,8 +55,8 @@ Args:
callback_context: MUST be named 'callback_context' (enforced).

Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
The content to return to the user. When set, the provided content will be
appended to event history as agent response.
"""

@@ -101,8 +101,8 @@ class BaseAgent(BaseModel):
callback_context: MUST be named 'callback_context' (enforced).

Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be returned to user.
The content to return to the user. When set, the agent run will be skipped
and the provided content will be returned to user.
"""
after_agent_callback: Optional[AfterAgentCallback] = None
"""Callback signature that is invoked after the agent run.

@@ -111,8 +111,8 @@ class BaseAgent(BaseModel):
callback_context: MUST be named 'callback_context' (enforced).

Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
The content to return to the user. When set, the provided content will be
appended to event history as agent response.
"""

@final

@@ -23,7 +23,6 @@ from .readonly_context import ReadonlyContext
if TYPE_CHECKING:
from google.genai import types

from ..events.event import Event
from ..events.event_actions import EventActions
from ..sessions.state import State
from .invocation_context import InvocationContext

@@ -15,12 +15,7 @@
from __future__ import annotations

import logging
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union

from google.genai import types
from pydantic import BaseModel

@@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Optional[dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Optional[dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]

InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
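Because the callback aliases now accept awaitables, a before_tool_callback may be a coroutine. A minimal sketch of what this enables; the quota check and its field names are hypothetical, not part of this commit:

from typing import Any, Optional

from google.adk.tools import BaseTool, ToolContext


async def check_quota(
    tool: BaseTool, args: dict[str, Any], tool_context: ToolContext
) -> Optional[dict]:
  # Returning a dict skips the real tool call and is used as its response.
  if args.get("user_tier") == "free":
    return {"error": "Quota exceeded for free tier."}
  return None  # Fall through to the actual tool.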

@@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
redirect_uri: Optional[str] = None
auth_response_uri: Optional[str] = None
auth_code: Optional[str] = None
token: Optional[Dict[str, Any]] = None
access_token: Optional[str] = None
refresh_token: Optional[str] = None


class ServiceAccountCredential(BaseModelWithConfig):

@@ -82,7 +82,8 @@ class AuthHandler:
or not auth_credential.oauth2
or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret
or auth_credential.oauth2.token
or auth_credential.oauth2.access_token
or auth_credential.oauth2.refresh_token
):
return self.auth_config.exchanged_auth_credential

@@ -93,7 +94,7 @@ class AuthHandler:
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
)
token = client.fetch_token(
tokens = client.fetch_token(
token_endpoint,
authorization_response=auth_credential.oauth2.auth_response_uri,
code=auth_credential.oauth2.auth_code,

@@ -102,7 +103,10 @@ class AuthHandler:

updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(token=dict(token)),
oauth2=OAuth2Auth(
access_token=tokens.get("access_token"),
refresh_token=tokens.get("refresh_token"),
),
)
return updated_credential
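With the opaque token dict split into explicit fields, the exchanged credential is built field by field from the token response. A hedged sketch; the token values are fabricated for illustration, and the import path is assumed from this repo's layout:

from google.adk.auth.auth_credential import (
    AuthCredential,
    AuthCredentialTypes,
    OAuth2Auth,
)

tokens = {"access_token": "ya29.fake", "refresh_token": "1//fake"}  # fabricated
credential = AuthCredential(
    auth_type=AuthCredentialTypes.OAUTH2,
    oauth2=OAuth2Auth(
        access_token=tokens.get("access_token"),
        refresh_token=tokens.get("refresh_token"),
    ),
)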

@@ -29,5 +29,5 @@
<style>html{color-scheme:dark}html{--mat-sys-background:light-dark(#fcf9f8, #131314);--mat-sys-error:light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container:light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface:light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary:light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface:light-dark(#313030, #e5e2e2);--mat-sys-on-background:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error:light-dark(#ffffff, #690005);--mat-sys-on-error-container:light-dark(#410002, #ffdad6);--mat-sys-on-primary:light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-on-secondary:light-dark(#ffffff, #003061);--mat-sys-on-secondary-container:light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed:light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant:light-dark(#0f4784, #0f4784);--mat-sys-on-surface:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant:light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary:light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-outline:light-dark(#74777b, #8e9194);--mat-sys-outline-variant:light-dark(#c4c7ca, #44474a);--mat-sys-primary:light-dark(#595f65, #c1c7cd);--mat-sys-primary-container:light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim:light-dark(#000000, #000000);--mat-sys-secondary:light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container:light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed:light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim:light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow:light-dark(#000000, #000000);--mat-sys-surface:light-dark(#fcf9f8, #131314);--mat-sys-surface-bright:light-dark(#fcf9f8, #393939);--mat-sys-surface-container:light-dark(#f0eded, #201f20);--mat-sys-surface-container-high:light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest:light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low:light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest:light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim:light-dark(#dcd9d9, #131314);--mat-sys-surface-tint:light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant:light-dark(#e1e2e6, #44474a);--mat-sys-tertiary:light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container:light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20:#2d3134;--mat-sys-neutral10:#1c1b1c}html{--mat-sys-level0:0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1:0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2:0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3:0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4:0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5:0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 
22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large:28px;--mat-sys-corner-extra-large-top:28px 28px 0 0;--mat-sys-corner-extra-small:4px;--mat-sys-corner-extra-small-top:4px 4px 0 0;--mat-sys-corner-full:9999px;--mat-sys-corner-large:16px;--mat-sys-corner-large-end:0 16px 16px 0;--mat-sys-corner-large-start:16px 0 0 16px;--mat-sys-corner-large-top:16px 16px 0 0;--mat-sys-corner-medium:12px;--mat-sys-corner-none:0;--mat-sys-corner-small:8px}html{--mat-sys-dragged-state-layer-opacity:.16;--mat-sys-focus-state-layer-opacity:.12;--mat-sys-hover-state-layer-opacity:.08;--mat-sys-pressed-state-layer-opacity:.12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}:root{--mat-sys-primary:black;--mdc-checkbox-selected-icon-color:white;--mat-sys-background:#131314;--mat-tab-header-active-label-text-color:#8AB4F8;--mat-tab-header-active-hover-label-text-color:#8AB4F8;--mat-tab-header-active-focus-label-text-color:#8AB4F8;--mat-tab-header-label-text-weight:500;--mdc-text-button-label-text-color:#89b4f8}:root{--mdc-dialog-container-color:#2b2b2f}:root{--mdc-dialog-subhead-color:white}:root{--mdc-circular-progress-active-indicator-color:#a8c7fa}:root{--mdc-circular-progress-size:80}</style><link rel="stylesheet" href="styles-4VDSPQ37.css" media="print" onload="this.media='all'"><noscript><link rel="stylesheet" href="styles-4VDSPQ37.css"></noscript></head>
<body>
<app-root></app-root>
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-ZBO76GRM.js" type="module"></script></body>
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-HWIBUY2R.js" type="module"></script></body>
</html>

File diff suppressed because one or more lines are too long

@@ -39,12 +39,12 @@ class InputFile(BaseModel):

async def run_input_file(
app_name: str,
user_id: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
input_path: str,
) -> None:
) -> Session:
runner = Runner(
app_name=app_name,
agent=root_agent,

@@ -55,9 +55,11 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()

session.state = input_file.state
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
click.echo(f'user: {query}')
click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content

@@ -65,23 +67,23 @@ async def run_input_file(
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
return session

async def run_interactively(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
) -> None:
runner = Runner(
app_name=app_name,
app_name=session.app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
while True:
query = input('user: ')
query = input('[user]: ')
if not query or not query.strip():
continue
if query == 'exit':

@@ -100,7 +102,8 @@ async def run_cli(
*,
agent_parent_dir: str,
agent_folder_name: str,
json_file_path: Optional[str] = None,
input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool,
) -> None:
"""Runs an interactive CLI for a certain agent.

@@ -109,8 +112,11 @@ async def run_cli(
agent_parent_dir: str, the absolute path of the parent folder of the agent
folder.
agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either
*.input.json or *.session.json.
input_file: Optional[str], the absolute path to the json file that contains
the initial session state and user queries, exclusive with
saved_session_file.
saved_session_file: Optional[str], the absolute path to the json file that
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit.
"""
if agent_parent_dir not in sys.path:

@@ -118,46 +124,50 @@ async def run_cli(

artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)

agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path:
if json_file_path.endswith('.input.json'):
await run_input_file(
app_name=agent_folder_name,
root_agent=root_agent,
artifact_service=artifact_service,
session=session,
session_service=session_service,
input_path=json_file_path,
)
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read())
for content in session.get_contents():
if content.role == 'user':
print('user: ', content.parts[0].text)
if input_file:
session = await run_input_file(
app_name=agent_folder_name,
user_id=user_id,
root_agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
input_path=input_file,
)
elif saved_session_file:

loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())

if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
if event.author == 'user':
click.echo(f'[user]: {content.parts[0].text}')
else:
print(content.parts[0].text)
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
else:
print(f'Unsupported file type: {json_file_path}')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
click.echo(f'[{event.author}]: {content.parts[0].text}')

await run_interactively(
root_agent,
artifact_service,
session,
session_service,
)
else:
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,

@@ -165,11 +175,8 @@ async def run_cli(
)

if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'

# Fetch the session again to get all the details.
session = session_service.get_session(
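run_input_file (and the new --replay flag below) consumes a JSON file matching InputFile: an initial state dict plus the user queries to replay. A minimal, hypothetical way to produce one; the state keys shown are made up:

import json

replay = {
    "state": {"user_name": "test"},  # becomes the new session's initial state
    "queries": ["Hello, what can you do?"],  # run in order, non-interactively
}
with open("hello.input.json", "w") as f:
    json.dump(replay, f)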

@@ -54,7 +54,7 @@ COPY "agents/{app_name}/" "/app/agents/{app_name}/"

EXPOSE {port}

CMD adk {command} --port={port} {trace_to_cloud_option} "/app/agents"
CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
"""

@@ -85,6 +85,7 @@ def to_cloud_run(
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Google Cloud Run.

@@ -112,6 +113,7 @@ def to_cloud_run(
trace_to_cloud: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
session_db_url: The database URL to connect the session.
"""
app_name = app_name or os.path.basename(agent_folder)

@@ -144,6 +146,9 @@ def to_cloud_run(
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
session_db_option=f'--session_db_url={session_db_url}'
if session_db_url
else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')

@@ -256,7 +256,7 @@ def run_evals(
)

if final_eval_status == EvalStatus.PASSED:
result = "✅ Passsed"
result = "✅ Passed"
else:
result = "❌ Failed"

@@ -96,6 +96,23 @@ def cli_create_cmd(
)


def validate_exclusive(ctx, param, value):
# Store the validated parameters in the context
if not hasattr(ctx, "exclusive_opts"):
ctx.exclusive_opts = {}

# If this option has a value and we've already seen another exclusive option
if value is not None and any(ctx.exclusive_opts.values()):
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
raise click.UsageError(
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
)

# Record this option's value
ctx.exclusive_opts[param.name] = value is not None
return value
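A quick way to exercise validate_exclusive is click's own test runner; passing both exclusive options should abort with the UsageError above. A sketch under assumptions: the `main` click group from this module is importable, and the placeholder paths must actually exist for click's exists=True check to pass before the callback runs.

from click.testing import CliRunner

result = CliRunner().invoke(
    main, ["run", "--replay", "a.json", "--resume", "b.json", "./my_agent"]
)
assert "cannot be set together" in result.output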

@main.command("run")
@click.option(
"--save_session",

@@ -105,13 +122,43 @@ def cli_create_cmd(
default=False,
help="Optional. Whether to save the session to a json file on exit.",
)
@click.option(
"--replay",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains the initial state of the session and user"
" queries. A new session will be created using this state. And user"
" queries are run againt the newly created session. Users cannot"
" continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.option(
"--resume",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains a previously saved session (by"
"--save_session option). The previous session will be re-displayed. And"
" user can continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(agent: str, save_session: bool):
def cli_run(
agent: str,
save_session: bool,
replay: Optional[str],
resume: Optional[str],
):
"""Runs an interactive CLI for a certain agent.

AGENT: The path to the agent source code folder.

@@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
run_cli(
agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name,
input_file=replay,
saved_session_file=resume,
save_session=save_session,
)
)

@@ -245,12 +294,13 @@ def cli_eval(
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
"""Optional. The database URL to store the session.

- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.

- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.

- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(

@@ -366,12 +416,13 @@ def cli_web(
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
"""Optional. The database URL to store the session.

- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.

- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.

- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(

@@ -541,6 +592,18 @@ def cli_api_server(
default="WARNING",
help="Optional. Override the default verbosity level.",
)
@click.option(
"--session_db_url",
help=(
"""Optional. The database URL to store the session.

- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.

- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.

- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.argument(
"agent",
type=click.Path(

@@ -558,6 +621,7 @@ def cli_deploy_cloud_run(
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Cloud Run.

@@ -579,6 +643,7 @@ def cli_deploy_cloud_run(
trace_to_cloud=trace_to_cloud,
with_ui=with_ui,
verbosity=verbosity,
session_db_url=session_db_url,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)

@@ -756,6 +756,12 @@ def get_fast_api_app(
except Exception as e:
logger.exception("Error during live websocket communication: %s", e)
traceback.print_exc()
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
await websocket.close(
code=WEBSOCKET_INTERNAL_ERROR_CODE,
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
)
finally:
for task in pending:
task.cancel()

@@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:


class AgentEvaluator:
"""An evaluator for Agents, mainly intented for helping with test cases."""
"""An evaluator for Agents, mainly intended for helping with test cases."""

@staticmethod
def find_config_for_test_file(test_file: str):

@@ -91,7 +91,7 @@ class AgentEvaluator:
look for 'root_agent' in the loaded module.
eval_dataset: The eval data set. This can be either a string representing
full path to the file containing eval dataset, or a directory that is
recusively explored for all files that have a `.test.json` suffix.
recursively explored for all files that have a `.test.json` suffix.
num_runs: Number of times all entries in the eval dataset should be
assessed.
agent_name: The name of the agent.

@@ -35,7 +35,7 @@ class ResponseEvaluator:
Args:
raw_eval_dataset: The dataset that will be evaluated.
evaluation_criteria: The evaluation criteria to be used. This method
support two criterias, `response_evaluation_score` and
support two criteria, `response_evaluation_score` and
`response_match_score`.
print_detailed_results: Prints detailed results on the console. This is
usually helpful during debugging.

@@ -56,7 +56,7 @@ class ResponseEvaluator:
Value range: [0, 5], where 0 means that the agent's response is not
coherent, while 5 means it is . High values are good.
A note on raw_eval_dataset:
The dataset should be a list session, where each sesssion is represented
The dataset should be a list session, where each session is represented
as a list of interaction that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:

@@ -31,10 +31,9 @@ class TrajectoryEvaluator:
):
r"""Returns the mean tool use accuracy of the eval dataset.

Tool use accuracy is calculated by comparing the expected and actuall tool
use trajectories. An exact match scores a 1, 0 otherwise. The final number
is an
average of these individual scores.
Tool use accuracy is calculated by comparing the expected and the actual
tool use trajectories. An exact match scores a 1, 0 otherwise. The final
number is an average of these individual scores.

Value range: [0, 1], where 0 is means none of the too use entries aligned,
and 1 would mean all of them aligned. Higher value is good.

@@ -45,7 +44,7 @@ class TrajectoryEvaluator:
usually helpful during debugging.

A note on eval_dataset:
The dataset should be a list session, where each sesssion is represented
The dataset should be a list session, where each session is represented
as a list of interaction that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:

@@ -48,8 +48,13 @@ class EventActions(BaseModel):
"""The agent is escalating to a higher level agent."""

requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
"""Will only be set by a tool response indicating tool request euc.
dict key is the function call id since one function call response (from model)
could correspond to multiple function calls.
dict value is the required auth config.
"""Authentication configurations requested by tool responses.

This field will only be set by a tool response event indicating tool request
auth credential.
- Keys: The function call id. Since one function response event could contain
multiple function responses that correspond to multiple function calls. Each
function call could request different auth configs. This id is used to
identify the function call.
- Values: The requested auth config.
"""

@@ -94,7 +94,7 @@ can answer it.

If another agent is better for answering the question according to its
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
question to that agent. When transfering, do not generate any text other than
question to that agent. When transferring, do not generate any text other than
the function call.
"""

@@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
yield event
# send back the function response
if event.get_function_responses():
logger.debug('Sending back last function resonse event: %s', event)
logger.debug('Sending back last function response event: %s', event)
invocation_context.live_request_queue.send_content(event.content)
if (
event.content

@@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
"""Rearrange the events for the latest function_response.

If the latest function_response is for an async function_call, all events
bewteen the initial function_call and the latest function_response will be
between the initial function_call and the latest function_response will be
removed.

Args:

@@ -151,28 +151,33 @@ async def handle_function_calls_async(
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
function_response: Optional[dict] = None

# before_tool_callback (sync or async)
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response

if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)

# Calls after_tool_callback if it exists.
# after_tool_callback (sync or async)
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response

if tool.is_long_running:
# Allow long running function to return None to not provide function response.
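The inspect.isawaitable check repeated above is the standard pattern for serving sync and async callbacks through one code path; it boils down to a helper like this sketch (the helper name is hypothetical, not in this commit):

import inspect


async def _maybe_await(value):
  # Await coroutine results; pass plain values through unchanged.
  return await value if inspect.isawaitable(value) else value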

@@ -223,11 +228,17 @@ async def handle_function_calls_live(
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
# # Calls the tool if before_tool_callback does not exist or returns None.
# if agent.before_tool_callback:
# function_response = agent.before_tool_callback(
# tool, function_args, tool_context
# )
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool, function_args, tool_context
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response

if not function_response:
function_response = await _process_function_live_helper(

@@ -235,15 +246,26 @@ async def handle_function_calls_live(
)

# Calls after_tool_callback if it exists.
# if agent.after_tool_callback:
# new_response = agent.after_tool_callback(
# tool,
# function_args,
# tool_context,
# function_response,
# )
# if new_response:
# function_response = new_response
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
altered_function_response = agent.after_tool_callback(
tool,
function_args,
tool_context,
function_response,
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response

if tool.is_long_running:
# Allow async function to return None to not provide function response.

@@ -310,9 +332,7 @@ async def _process_function_live_helper(
function_response = {
'status': f'No active streaming function named {function_name} found'
}
elif inspect.isasyncgenfunction(tool.func):
print('is async')

elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
# for streaming tool use case
# we require the function to be a async generator function
async def run_tool_and_update_queue(tool, function_args, tool_context):

@@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
# Appends global instructions if set.
if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not emtpy str
): # not empty str
raw_si = root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)

@@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
llm_request.append_instructions([si])

# Appends agent instructions if set.
if agent.instruction: # not emtpy str
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
si = _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])

@@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
):
# TODO: Right now, we just support output_transcription without
# changing interface and data protocol. Later, we can consider to
# support output_transcription as a separete field in LlmResponse.
# support output_transcription as a separate field in LlmResponse.

# Transcription is always considered as partial event
# We rely on other control signals to determine when to yield the

@@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
# in case of empty content or parts, we sill surface it
# in case it's an interrupted message, we merge the previous partial
# text. Other we don't merge. because content can be none when model
# safty threshold is triggered
# safety threshold is triggered
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''

@@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Optional
from typing import Any, Optional

from google.genai import types
from pydantic import BaseModel

@@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
error_message: Error message if the response is an error.
interrupted: Flag indicating that LLM was interrupted when generating the
content. Usually it's due to user interruption during a bidi streaming.
custom_metadata: The custom metadata of the LlmResponse.
"""

model_config = ConfigDict(extra='forbid')

@@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
Usually it's due to user interruption during a bidi streaming.
"""

custom_metadata: Optional[dict[str, Any]] = None
"""The custom metadata of the LlmResponse.

An optional key-value pair to label an LlmResponse.

NOTE: the entire dict must be JSON serializable.
"""

@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,
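Since the model forbids extra fields, custom_metadata is the sanctioned place to attach caller-defined labels. A small illustrative use; the keys are made up, the import path is assumed from this diff's file, and the whole dict must stay JSON-serializable:

from google.adk.models.llm_response import LlmResponse

response = LlmResponse(
    custom_metadata={"experiment": "baseline", "trace_id": "abc123"},
)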

@@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
llm_request: The LLM request to apply the thinking config to.
"""
if self.thinking_config:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.thinking_config = self.thinking_config

@override

@@ -0,0 +1,29 @@
"""Utility functions for session service."""

import base64
from typing import Any, Optional

from google.genai import types


def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content


def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content)
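These new helpers let binary parts survive a JSON round trip: bytes are base64-encoded on the way out and decoded on the way back. A short round-trip sketch, assuming the helpers above are in scope and using a fabricated payload:

from google.genai import types

original = types.Content(
    role="user",
    parts=[types.Part.from_bytes(data=b"\x89PNG...", mime_type="image/png")],
)
encoded = encode_content(original)  # inline_data bytes -> base64 str
restored = decode_content(encoded)  # base64 str -> bytes
assert restored == original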

@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import copy
from datetime import datetime
import json
import logging
from typing import Any
from typing import Optional
from typing import Any, Optional
import uuid

from sqlalchemy import Boolean

@@ -27,6 +24,7 @@ from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine

@@ -48,6 +46,7 @@ from typing_extensions import override
from tzlocal import get_localzone

from ..events.event import Event
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse

@@ -58,6 +57,9 @@ from .state import State

logger = logging.getLogger(__name__)

DEFAULT_MAX_KEY_LENGTH = 128
DEFAULT_MAX_VARCHAR_LENGTH = 256


class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON

@@ -70,15 +72,16 @@ class DynamicJSON(TypeDecorator):
def load_dialect_impl(self, dialect: Dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
return dialect.type_descriptor(Text)  # Default to Text for other dialects
if dialect.name == "mysql":
# Use LONGTEXT for MySQL to address the data too long issue
return dialect.type_descriptor(mysql.LONGTEXT)
return dialect.type_descriptor(Text)  # Default to Text for other dialects

def process_bind_param(self, value, dialect: Dialect):
if value is not None:
if dialect.name == "postgresql":
return value  # JSONB handles dict directly
else:
return json.dumps(value)  # Serialize to JSON string for TEXT
return json.dumps(value)  # Serialize to JSON string for TEXT
return value

def process_result_value(self, value, dialect: Dialect):

@@ -92,17 +95,25 @@ class DynamicJSON(TypeDecorator):

class Base(DeclarativeBase):
"""Base class for database tables."""

pass


class StorageSession(Base):
"""Represents a session stored in the database."""

__tablename__ = "sessions"

app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
String(DEFAULT_MAX_KEY_LENGTH),
primary_key=True,
default=lambda: str(uuid.uuid4()),
)

state: Mapped[MutableDict[str, Any]] = mapped_column(

@@ -125,18 +136,29 @@ class StorageSession(Base):

class StorageEvent(Base):
"""Represents an event stored in the database."""

__tablename__ = "events"

id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)

invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
branch: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)

long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(

@@ -147,8 +169,10 @@ class StorageEvent(Base):
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
error_code: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)

storage_session: Mapped[StorageSession] = relationship(

@@ -182,9 +206,12 @@ class StorageEvent(Base):

class StorageAppState(Base):
"""Represents an app state stored in the database."""

__tablename__ = "app_states"

app_name: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)

@@ -192,13 +219,17 @@ class StorageAppState(Base):
DateTime(), default=func.now(), onupdate=func.now()
)


class StorageUserState(Base):
"""Represents a user state stored in the database."""

__tablename__ = "user_states"

app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)

@@ -217,7 +248,7 @@ class DatabaseSessionService(BaseSessionService):
"""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properies
# 3. Initialize all properties

try:
db_engine = create_engine(db_url)

@@ -353,6 +384,7 @@ class DatabaseSessionService(BaseSessionService):
else True
)
.limit(config.num_recent_events if config else None)
.order_by(StorageEvent.timestamp.asc())
.all()
)

@@ -383,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_decode_content(e.content),
content=_session_util.decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,

@@ -506,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
storage_event.content = _session_util.encode_content(event.content)

sessionFactory.add(storage_event)

@@ -574,10 +598,3 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state


def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return content

@@ -26,7 +26,7 @@ class State:
"""
Args:
value: The current value of the state dict.
delta: The delta change to the current value that hasn't been commited.
delta: The delta change to the current value that hasn't been committed.
"""
self._value = value
self._delta = delta

@@ -14,21 +14,23 @@
import logging
import re
import time
from typing import Any
from typing import Optional
from typing import Any, Optional

from dateutil.parser import isoparse
from dateutil import parser
from google import genai
from typing_extensions import override

from ..events.event import Event
from ..events.event_actions import EventActions
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session


isoparse = parser.isoparse
logger = logging.getLogger(__name__)

@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = event.content.model_dump(exclude_none=True)
event_json['content'] = _session_util.encode_content(event.content)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:

@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
invocation_id=api_event['invocationId'],
author=api_event['author'],
actions=event_actions,
content=api_event.get('content', None),
content=_session_util.decode_content(api_event.get('content', None)),
timestamp=isoparse(api_event['timestamp']).timestamp(),
error_code=api_event.get('errorCode', None),
error_message=api_event.get('errorMessage', None),

@@ -45,10 +45,9 @@ class AgentTool(BaseTool):
skip_summarization: Whether to skip summarization of the agent output.
"""

def __init__(self, agent: BaseAgent):
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
self.agent = agent
self.skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""
self.skip_summarization: bool = skip_summarization

super().__init__(name=agent.name, description=agent.description)
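With the new constructor parameter, summarization can be disabled at construction time instead of mutating the attribute afterwards. For example (the wrapped search_agent is hypothetical):

search_tool = AgentTool(agent=search_agent, skip_summarization=True)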

@@ -13,7 +13,9 @@
# limitations under the License.

from .application_integration_toolset import ApplicationIntegrationToolset
from .integration_connector_tool import IntegrationConnectorTool

__all__ = [
'ApplicationIntegrationToolset',
'IntegrationConnectorTool',
]

@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import List
from typing import Optional
from typing import Dict, List, Optional

from fastapi.openapi.models import HTTPBearer
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool

from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.connections_client import ConnectionsClient
from .clients.integration_client import IntegrationClient
from .integration_connector_tool import IntegrationConnectorTool


# TODO(cheliu): Apply a common toolset interface

@@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
actions,
service_account_json,
)
connection_details = {}
if integration and trigger:
spec = integration_client.get_openapi_spec_for_integration()
elif connection and (entity_operations or actions):

@@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
project, location, connection, service_account_json
)
connection_details = connections_client.get_connection_details()
tool_instructions += (
"ALWAYS use serviceName = "
+ connection_details["serviceName"]
+ ", host = "
+ connection_details["host"]
+ " and the connection name = "
+ f"projects/{project}/locations/{location}/connections/{connection} when"
" using this tool"
+ ". DONOT ask the user for these values as you already have those."
)
spec = integration_client.get_openapi_spec_for_connection(
tool_name,
tool_instructions,

@@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
"Either (integration and trigger) or (connection and"
" (entity_operations or actions)) should be provided."
)
self._parse_spec_to_tools(spec)
self._parse_spec_to_tools(spec, connection_details)

def _parse_spec_to_tools(self, spec_dict):
def _parse_spec_to_tools(self, spec_dict, connection_details):
"""Parses the spec dict to a list of RestApiTool."""
if self.service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json(

@@ -218,12 +209,43 @@ class ApplicationIntegrationToolset:
),
)
auth_scheme = HTTPBearer(bearerFormat="JWT")
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
auth_scheme=auth_scheme,
).get_tools()
for tool in tools:

if self.integration and self.trigger:
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
auth_scheme=auth_scheme,
).get_tools()
for tool in tools:
self.generated_tools[tool.name] = tool
return

operations = OpenApiSpecParser().parse(spec_dict)

for open_api_operation in operations:
operation = getattr(open_api_operation.operation, "x-operation")
entity = None
action = None
if hasattr(open_api_operation.operation, "x-entity"):
entity = getattr(open_api_operation.operation, "x-entity")
elif hasattr(open_api_operation.operation, "x-action"):
action = getattr(open_api_operation.operation, "x-action")
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
if auth_scheme:
rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential)
tool = IntegrationConnectorTool(
name=rest_api_tool.name,
description=rest_api_tool.description,
connection_name=connection_details["name"],
connection_host=connection_details["host"],
connection_service_name=connection_details["serviceName"],
entity=entity,
action=action,
operation=operation,
rest_api_tool=rest_api_tool,
)
self.generated_tools[tool.name] = tool

def get_tools(self) -> List[RestApiTool]:
@@ -68,12 +68,14 @@ class ConnectionsClient:
|
||||
response = self._execute_api_call(url)
|
||||
|
||||
connection_data = response.json()
|
||||
connection_name = connection_data.get("name", "")
|
||||
service_name = connection_data.get("serviceDirectory", "")
|
||||
host = connection_data.get("host", "")
|
||||
if host:
|
||||
service_name = connection_data.get("tlsServiceDirectory", "")
|
||||
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
|
||||
return {
|
||||
"name": connection_name,
|
||||
"serviceName": service_name,
|
||||
"host": host,
|
||||
"authOverrideEnabled": auth_override_enabled,
|
||||
@@ -291,13 +293,9 @@ class ConnectionsClient:
      tool_name: str = "",
      tool_instructions: str = "",
  ) -> Dict[str, Any]:
    description = (
        f"Use this tool with" f' action = "{action}" and'
    ) + f' operation = "{operation}" only. Dont ask these values from user.'
    description = f"Use this tool to execute {action}"
    if operation == "EXECUTE_QUERY":
      description = (
          (f"Use this tool with" f' action = "{action}" and')
          + f' operation = "{operation}" only. Dont ask these values from user.'
      description += (
          " Use pageSize = 50 and timeout = 120 until user specifies a"
          " different value otherwise. If user provides a query in natural"
          " language, convert it to SQL query and then execute it using the"
@@ -308,6 +306,8 @@ class ConnectionsClient:
            "summary": f"{action_display_name}",
            "description": f"{description} {tool_instructions}",
            "operationId": f"{tool_name}_{action_display_name}",
            "x-action": f"{action}",
            "x-operation": f"{operation}",
            "requestBody": {
                "content": {
                    "application/json": {
@@ -347,16 +347,12 @@ class ConnectionsClient:
        "post": {
            "summary": f"List {entity}",
            "description": (
                f"Returns all entities of type {entity}. Use this tool with"
                + f' entity = "{entity}" and'
                + ' operation = "LIST_ENTITIES" only. Dont ask these values'
                " from"
                + ' user. Always use ""'
                + ' as filter clause and ""'
                + " as page token and 50 as page size until user specifies a"
                " different value otherwise. Use single quotes for strings in"
                f" filter clause. {tool_instructions}"
                f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
following format: `field_name1='value1' AND field_name2='value2'
`. {tool_instructions}"""
            ),
            "x-operation": "LIST_ENTITIES",
            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_list_{entity}",
            "requestBody": {
                "content": {
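The rewritten LIST_ENTITIES description commits to a concrete filter grammar. A couple of filter strings in that format, shown purely as illustration:

```
# Illustrative filter clauses in the documented format: single quotes around
# string literals, AND between conditions.
filters = [
    "status='OPEN'",
    "priority='P1' AND assignee='alice@example.com'",
]
```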
@@ -401,14 +397,11 @@ class ConnectionsClient:
        "post": {
            "summary": f"Get {entity}",
            "description": (
                (
                    f"Returns the details of the {entity}. Use this tool with"
                    f' entity = "{entity}" and'
                )
                + ' operation = "GET_ENTITY" only. Dont ask these values from'
                f" user. {tool_instructions}"
                f"Returns the details of the {entity}. {tool_instructions}"
            ),
            "operationId": f"{tool_name}_get_{entity}",
            "x-operation": "GET_ENTITY",
            "x-entity": f"{entity}",
            "requestBody": {
                "content": {
                    "application/json": {
@@ -445,17 +438,10 @@ class ConnectionsClient:
  ) -> Dict[str, Any]:
    return {
        "post": {
            "summary": f"Create {entity}",
            "description": (
                (
                    f"Creates a new entity of type {entity}. Use this tool with"
                    f' entity = "{entity}" and'
                )
                + ' operation = "CREATE_ENTITY" only. Dont ask these values'
                " from"
                + " user. Follow the schema of the entity provided in the"
                f" instructions to create {entity}. {tool_instructions}"
            ),
            "summary": f"Creates a new {entity}",
            "description": f"Creates a new {entity}. {tool_instructions}",
            "x-operation": "CREATE_ENTITY",
            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_create_{entity}",
            "requestBody": {
                "content": {
@@ -491,18 +477,10 @@ class ConnectionsClient:
  ) -> Dict[str, Any]:
    return {
        "post": {
            "summary": f"Update {entity}",
            "description": (
                (
                    f"Updates an entity of type {entity}. Use this tool with"
                    f' entity = "{entity}" and'
                )
                + ' operation = "UPDATE_ENTITY" only. Dont ask these values'
                " from"
                + " user. Use entityId to uniquely identify the entity to"
                " update. Follow the schema of the entity provided in the"
                f" instructions to update {entity}. {tool_instructions}"
            ),
            "summary": f"Updates the {entity}",
            "description": f"Updates the {entity}. {tool_instructions}",
            "x-operation": "UPDATE_ENTITY",
            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_update_{entity}",
            "requestBody": {
                "content": {
@@ -538,16 +516,10 @@ class ConnectionsClient:
  ) -> Dict[str, Any]:
    return {
        "post": {
            "summary": f"Delete {entity}",
            "description": (
                (
                    f"Deletes an entity of type {entity}. Use this tool with"
                    f' entity = "{entity}" and'
                )
                + ' operation = "DELETE_ENTITY" only. Dont ask these values'
                " from"
                f" user. {tool_instructions}"
            ),
            "summary": f"Delete the {entity}",
            "description": f"Deletes the {entity}. {tool_instructions}",
            "x-operation": "DELETE_ENTITY",
            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_delete_{entity}",
            "requestBody": {
                "content": {
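Net effect of these schema changes: the routing hints move out of the prose descriptions and into machine-readable `x-entity`/`x-operation` (or `x-action`) extension fields, which `_parse_spec_to_tools` reads back via `getattr`/`hasattr`. A minimal sketch of one generated operation, with invented values:

```
# Illustrative fragment of a generated spec operation after this change.
list_issues_operation = {
    "post": {
        "summary": "List Issues",
        "operationId": "my_tool_list_Issues",
        "x-operation": "LIST_ENTITIES",
        "x-entity": "Issues",
    }
}
```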
@@ -0,0 +1,159 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from typing import Any
from typing import Dict
from typing import Optional

from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration
from typing_extensions import override

from .. import BaseTool
from ..tool_context import ToolContext

logger = logging.getLogger(__name__)


class IntegrationConnectorTool(BaseTool):
  """A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.

  This tool adds Application Integration specific context like connection
  details, entity, operation, and action to the underlying REST API call
  handled by RestApiTool. It prepares the arguments and then delegates the
  actual API call execution to the contained RestApiTool instance.

  * Generates request params and body
  * Attaches auth credentials to the API call.

  Example:
    ```
    # Each API operation in the spec will be turned into its own tool
    # Name of the tool is the operationId of that operation, in snake case
    operations = OperationGenerator().parse(openapi_spec_dict)
    tool = [RestApiTool.from_parsed_operation(o) for o in operations]
    ```
  """

  EXCLUDE_FIELDS = [
      'connection_name',
      'service_name',
      'host',
      'entity',
      'operation',
      'action',
  ]

  OPTIONAL_FIELDS = [
      'page_size',
      'page_token',
      'filter',
  ]

  def __init__(
      self,
      name: str,
      description: str,
      connection_name: str,
      connection_host: str,
      connection_service_name: str,
      entity: str,
      operation: str,
      action: str,
      rest_api_tool: RestApiTool,
  ):
    """Initializes the IntegrationConnectorTool.

    Args:
      name: The name of the tool, typically derived from the API operation.
        Should be unique and adhere to Gemini function naming conventions
        (e.g., less than 64 characters).
      description: A description of what the tool does, usually based on the
        API operation's summary or description.
      connection_name: The name of the Integration Connector connection.
      connection_host: The hostname or IP address for the connection.
      connection_service_name: The specific service name within the host.
      entity: The Integration Connector entity being targeted.
      operation: The specific operation being performed on the entity.
      action: The action associated with the operation (e.g., 'execute').
      rest_api_tool: An initialized RestApiTool instance that handles the
        underlying REST API communication based on an OpenAPI specification
        operation. This tool will be called by IntegrationConnectorTool with
        added connection and context arguments.
    """
    # Gemini restricts the length of function names to less than 64 characters.
    super().__init__(
        name=name,
        description=description,
    )
    self.connection_name = connection_name
    self.connection_host = connection_host
    self.connection_service_name = connection_service_name
    self.entity = entity
    self.operation = operation
    self.action = action
    self.rest_api_tool = rest_api_tool

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Returns the function declaration in the Gemini Schema format."""
    schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
    for field in self.EXCLUDE_FIELDS:
      if field in schema_dict['properties']:
        del schema_dict['properties'][field]
    for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
      if field in schema_dict['required']:
        schema_dict['required'].remove(field)

    parameters = to_gemini_schema(schema_dict)
    function_decl = FunctionDeclaration(
        name=self.name, description=self.description, parameters=parameters
    )
    return function_decl
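A small before/after sketch of what this filtering achieves; the schema content is invented, with names mirroring EXCLUDE_FIELDS and OPTIONAL_FIELDS above:

```
# Illustrative JSON schema passed through _get_declaration's filtering.
before = {
    "properties": {
        "connection_name": {"type": "string"},  # plumbing: excluded entirely
        "filter": {"type": "string"},           # demoted to optional
        "sortByColumns": {"type": "string"},    # invented entity field
    },
    "required": ["connection_name", "filter", "sortByColumns"],
}
# After filtering: connection_name is gone from properties, and both
# connection_name and filter are removed from "required".
after = {
    "properties": {
        "filter": {"type": "string"},
        "sortByColumns": {"type": "string"},
    },
    "required": ["sortByColumns"],
}
```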

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
  ) -> Dict[str, Any]:
    args['connection_name'] = self.connection_name
    args['service_name'] = self.connection_service_name
    args['host'] = self.connection_host
    args['entity'] = self.entity
    args['operation'] = self.operation
    args['action'] = self.action
    logger.info('Running tool: %s with args: %s', self.name, args)
    return self.rest_api_tool.call(args=args, tool_context=tool_context)
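run_async re-injects exactly the context that _get_declaration hid from the model, so the wrapped RestApiTool still receives a complete argument set. A hedged invocation sketch with invented argument values:

```
# Hypothetical call: the model supplies only entity-level arguments; the
# connection_name / host / entity / operation / action keys are added by
# run_async before delegating to rest_api_tool.call().
result = await tool.run_async(
    args={"filter": "status='OPEN'", "page_size": 50},
    tool_context=None,  # Optional[ToolContext] per the signature above
)
```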

  def __str__(self):
    return (
        f'ApplicationIntegrationTool(name="{self.name}",'
        f' description="{self.description}",'
        f' connection_name="{self.connection_name}", entity="{self.entity}",'
        f' operation="{self.operation}", action="{self.action}")'
    )

  def __repr__(self):
    return (
        f'ApplicationIntegrationTool(name="{self.name}",'
        f' description="{self.description}",'
        f' connection_name="{self.connection_name}",'
        f' connection_host="{self.connection_host}",'
        f' connection_service_name="{self.connection_service_name}",'
        f' entity="{self.entity}", operation="{self.operation}",'
        f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
    )

@@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
    if 'tool_context' in signature.parameters:
      args_to_call['tool_context'] = tool_context

    # Before invoking the function, check whether the list of args passed in
    # has all the mandatory arguments.
    # If the check fails, we don't invoke the tool and instead let the Agent
    # know that an input parameter was missing. This helps the underlying
    # model fix the issue and retry.
    mandatory_args = self._get_mandatory_args()
    missing_mandatory_args = [
        arg for arg in mandatory_args if arg not in args_to_call
    ]

    if missing_mandatory_args:
      missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
      error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
{missing_mandatory_args_str}
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
      return {'error': error_str}

    if inspect.iscoroutinefunction(self.func):
      return await self.func(**args_to_call) or {}
    else:
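When the check trips, the caller gets a structured error instead of an exception; for a hypothetical `get_weather` tool missing its `city` parameter, the return value would be:

```
# Illustrative return value produced by the f-string above.
error_response = {
    'error': (
        'Invoking `get_weather()` failed as the following mandatory input'
        ' parameters are not present:\ncity\nYou could retry calling this'
        ' tool, but it is IMPORTANT for you to provide all the mandatory'
        ' parameters.'
    )
}
```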
@@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
      args_to_call['tool_context'] = tool_context
    async for item in self.func(**args_to_call):
      yield item

  def _get_mandatory_args(
      self,
  ) -> list[str]:
    """Identifies mandatory parameters (those without default values) for a function.

    Returns:
      A list of strings, where each string is the name of a mandatory parameter.
    """
    signature = inspect.signature(self.func)
    mandatory_params = []

    for name, param in signature.parameters.items():
      # A parameter is mandatory if:
      # 1. It has no default value (param.default is inspect.Parameter.empty)
      # 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
      #
      # For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
      if param.default == inspect.Parameter.empty and param.kind not in (
          inspect.Parameter.VAR_POSITIONAL,
          inspect.Parameter.VAR_KEYWORD,
      ):
        mandatory_params.append(name)

    return mandatory_params
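The rule is plain `inspect` machinery; the same logic works standalone:

```
# Self-contained demo of the mandatory-parameter rule used above.
import inspect

def example(a, b=1, *args, c, **kwargs):
  """a and c are mandatory; b has a default; *args/**kwargs are variadic."""

mandatory = [
    name
    for name, p in inspect.signature(example).parameters.items()
    if p.default is inspect.Parameter.empty
    and p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                       inspect.Parameter.VAR_KEYWORD)
]
print(mandatory)  # ['a', 'c']
```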
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
import os
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Optional
@@ -28,6 +30,7 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter


class GoogleApiToolSet:
  """Google API Tool Set."""

  def __init__(self, tools: List[RestApiTool]):
    self.tools: Final[List[GoogleApiTool]] = [
@@ -45,10 +48,10 @@ class GoogleApiToolSet:

  @staticmethod
  def _load_tool_set_with_oidc_auth(
      spec_file: str = None,
      spec_dict: Dict[str, Any] = None,
      scopes: list[str] = None,
  ) -> Optional[OpenAPIToolset]:
      spec_file: Optional[str] = None,
      spec_dict: Optional[dict[str, Any]] = None,
      scopes: Optional[list[str]] = None,
  ) -> OpenAPIToolset:
    spec_str = None
    if spec_file:
      # Get the frame of the caller
@@ -90,18 +93,18 @@ class GoogleApiToolSet:

  @classmethod
  def load_tool_set(
      cl: Type['GoogleApiToolSet'],
      cls: Type[GoogleApiToolSet],
      api_name: str,
      api_version: str,
  ) -> 'GoogleApiToolSet':
  ) -> GoogleApiToolSet:
    spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
    scope = list(
        spec_dict['components']['securitySchemes']['oauth2']['flows'][
            'authorizationCode'
        ]['scopes'].keys()
    )[0]
    return cl(
        cl._load_tool_set_with_oidc_auth(
    return cls(
        cls._load_tool_set_with_oidc_auth(
            spec_dict=spec_dict, scopes=[scope]
        ).get_tools()
    )
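With the `cl` to `cls` rename and deferred annotations, usage stays a one-liner; the API name and version here are illustrative:

```
# Hypothetical usage; any discoverable Google API name/version should work
# in principle, e.g. calendar/v3.
calendar_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="calendar",
    api_version="v3",
)
for tool in calendar_tool_set.tools:  # List[GoogleApiTool] per __init__ above
  print(tool.name)  # assumes the wrapped tools expose a `name` attribute
```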
@@ -89,7 +89,7 @@ class LoadArtifactsTool(BaseTool):
      than the function call.
    """])

    # Attache the content of the artifacts if the model requests them.
    # Attach the content of the artifacts if the model requests them.
    # This only adds the content to the model request, instead of the session.
    if llm_request.contents and llm_request.contents[-1].parts:
      function_response = llm_request.contents[-1].parts[0].function_response

@@ -66,10 +66,10 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):

    Returns:
      An AuthCredential object containing the HTTP bearer access token. If the
      HTTO bearer token cannot be generated, return the origianl credential
      HTTP bearer token cannot be generated, return the original credential.
    """

    if "access_token" not in auth_credential.oauth2.token:
    if not auth_credential.oauth2.access_token:
      return auth_credential

    # Return the access token as a bearer token.
@@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
        http=HttpAuth(
            scheme="bearer",
            credentials=HttpCredentials(
                token=auth_credential.oauth2.token["access_token"]
                token=auth_credential.oauth2.access_token
            ),
        ),
    )
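The move from the raw `token` dict to the typed `access_token` field makes the exchange condition explicit. A hedged sketch of the flow; only the class and field names come from the diff, the constructor details below are assumptions:

```
# Hedged sketch: an oauth2 credential that already carries an access token is
# converted into an HTTP bearer credential; exact constructors are assumed.
cred = AuthCredential(
    auth_type=AuthCredentialTypes.OAUTH2,       # assumed enum name
    oauth2=OAuth2Auth(access_token="ya29.example-token"),
)
bearer = OAuth2CredentialExchanger().generate_auth_token(cred)
# bearer.http.scheme == "bearer"
# bearer.http.credentials.token == "ya29.example-token"
```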
@@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
      return auth_credential

    # If access token is exchanged, exchange a HTTPBearer token.
    if auth_credential.oauth2.token:
    if auth_credential.oauth2.access_token:
      return self.generate_auth_token(auth_credential)

    return None

@@ -124,7 +124,7 @@ class OpenAPIToolset:
  def _load_spec(
      self, spec_str: str, spec_type: Literal["json", "yaml"]
  ) -> Dict[str, Any]:
    """Loads the OpenAPI spec string into adictionary."""
    """Loads the OpenAPI spec string into a dictionary."""
    if spec_type == "json":
      return json.loads(spec_str)
    elif spec_type == "yaml":

@@ -14,20 +14,12 @@

import inspect
from textwrap import dedent
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from typing import Any, Dict, List, Optional, Union

from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from fastapi.openapi.models import Operation, Parameter, Schema

from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
from ..common.common import ApiParameter, PydocHelper, to_snake_case


class OperationParser:
@@ -113,7 +105,8 @@ class OperationParser:
    description = request_body.description or ''

    if schema and schema.type == 'object':
      for prop_name, prop_details in schema.properties.items():
      properties = schema.properties or {}
      for prop_name, prop_details in properties.items():
        self.params.append(
            ApiParameter(
                original_name=prop_name,

@@ -17,6 +17,7 @@ from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

@@ -59,6 +60,40 @@ def snake_to_lower_camel(snake_case_string: str):
    ])


# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def normalize_json_schema_type(
    json_schema_type: Optional[Union[str, Sequence[str]]],
) -> tuple[Optional[str], bool]:
  """Converts a JSON Schema type into a Gemini Schema type.

  Adapted and modified from the Gemini SDK. This gets the first available
  schema type from JSON Schema and uses it as the Gemini schema type. If the
  JSON Schema contains a list of types, the first non-null type is used.

  Remove this after switching to Gemini `from_json_schema`.
  """
  if json_schema_type is None:
    return None, False
  if isinstance(json_schema_type, str):
    if json_schema_type == "null":
      return None, True
    return json_schema_type, False

  non_null_types = []
  nullable = False
  # If json schema type is an array, pick the first non null type.
  for type_value in json_schema_type:
    if type_value == "null":
      nullable = True
    else:
      non_null_types.append(type_value)
  non_null_type = non_null_types[0] if non_null_types else None
  return non_null_type, nullable
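Worked input/output pairs for the normalizer, matching the branches above:

```
# Worked examples for normalize_json_schema_type.
assert normalize_json_schema_type(None) == (None, False)
assert normalize_json_schema_type("string") == ("string", False)
assert normalize_json_schema_type("null") == (None, True)
assert normalize_json_schema_type(["null", "integer"]) == ("integer", True)
assert normalize_json_schema_type(["string", "number"]) == ("string", False)
```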

# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
  """Converts an OpenAPI schema dictionary to a Gemini Schema object.

@@ -82,13 +117,6 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
  if not openapi_schema.get("type"):
    openapi_schema["type"] = "object"

  # Adding this to avoid "properties: should be non-empty for OBJECT type" error
  # See b/385165182
  if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
      "properties"
  ):
    openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}

  for key, value in openapi_schema.items():
    snake_case_key = to_snake_case(key)
    # Check if the snake_case_key exists in the Schema model's fields.
@@ -99,7 +127,17 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
      # Format: properties[expiration].format: only 'enum' and 'date-time' are
      # supported for STRING type
      continue
    if snake_case_key == "properties" and isinstance(value, dict):
    elif snake_case_key == "type":
      schema_type, nullable = normalize_json_schema_type(
          openapi_schema.get("type", None)
      )
      # Force a type onto an empty dict; this avoids the
      # "... one_of or any_of must specify a type" error.
      pydantic_schema_data["type"] = schema_type if schema_type else "object"
      pydantic_schema_data["type"] = pydantic_schema_data["type"].upper()
      if nullable:
        pydantic_schema_data["nullable"] = True
    elif snake_case_key == "properties" and isinstance(value, dict):
      pydantic_schema_data[snake_case_key] = {
          k: to_gemini_schema(v) for k, v in value.items()
      }
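End to end, a JSON Schema union type now lands as an uppercase Gemini type plus a nullable flag; a sketch of the expected behavior, with the exact Schema field representation assumed:

```
# Illustrative: per the "type" branch above, a ["string", "null"] union is
# normalized to type "STRING" with nullable=True before Schema construction.
schema = to_gemini_schema({"type": ["string", "null"]})
# expected: schema.type corresponds to STRING and schema.nullable is True
```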

@@ -13,4 +13,4 @@
# limitations under the License.

# version: date+base_cl
__version__ = "0.1.1"
__version__ = "0.3.0"