mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-04 13:56:24 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
File diff suppressed because one or more lines are too long
+65
-81
File diff suppressed because one or more lines are too long
@@ -82,6 +82,8 @@ async def run_interactively(
|
||||
)
|
||||
while True:
|
||||
query = input('user: ')
|
||||
if not query or not query.strip():
|
||||
continue
|
||||
if query == 'exit':
|
||||
break
|
||||
async for event in runner.run_async(
|
||||
|
||||
@@ -0,0 +1,279 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import click
|
||||
|
||||
_INIT_PY_TEMPLATE = """\
|
||||
from . import agent
|
||||
"""
|
||||
|
||||
_AGENT_PY_TEMPLATE = """\
|
||||
from google.adk.agents import Agent
|
||||
|
||||
root_agent = Agent(
|
||||
model='{model_name}',
|
||||
name='root_agent',
|
||||
description='A helpful assistant for user questions.',
|
||||
instruction='Answer user questions to the best of your knowledge',
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
_GOOGLE_API_MSG = """
|
||||
Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey
|
||||
"""
|
||||
|
||||
_GOOGLE_CLOUD_SETUP_MSG = """
|
||||
You need an existing Google Cloud account and project, check out this link for details:
|
||||
https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai
|
||||
"""
|
||||
|
||||
_OTHER_MODEL_MSG = """
|
||||
Please see below guide to configure other models:
|
||||
https://google.github.io/adk-docs/agents/models
|
||||
"""
|
||||
|
||||
_SUCCESS_MSG = """
|
||||
Agent created in {agent_folder}:
|
||||
- .env
|
||||
- __init__.py
|
||||
- agent.py
|
||||
"""
|
||||
|
||||
|
||||
def _get_gcp_project_from_gcloud() -> str:
|
||||
"""Uses gcloud to get default project."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gcloud", "config", "get-value", "project"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return ""
|
||||
|
||||
|
||||
def _get_gcp_region_from_gcloud() -> str:
|
||||
"""Uses gcloud to get default region."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gcloud", "config", "get-value", "compute/region"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return ""
|
||||
|
||||
|
||||
def _prompt_str(
|
||||
prompt_prefix: str,
|
||||
*,
|
||||
prior_msg: Optional[str] = None,
|
||||
default_value: Optional[str] = None,
|
||||
) -> str:
|
||||
if prior_msg:
|
||||
click.secho(prior_msg, fg="green")
|
||||
while True:
|
||||
value: str = click.prompt(
|
||||
prompt_prefix, default=default_value or None, type=str
|
||||
)
|
||||
if value and value.strip():
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _prompt_for_google_cloud(
|
||||
google_cloud_project: Optional[str],
|
||||
) -> str:
|
||||
"""Prompts user for Google Cloud project ID."""
|
||||
google_cloud_project = (
|
||||
google_cloud_project
|
||||
or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
|
||||
or _get_gcp_project_from_gcloud()
|
||||
)
|
||||
|
||||
google_cloud_project = _prompt_str(
|
||||
"Enter Google Cloud project ID", default_value=google_cloud_project
|
||||
)
|
||||
|
||||
return google_cloud_project
|
||||
|
||||
|
||||
def _prompt_for_google_cloud_region(
|
||||
google_cloud_region: Optional[str],
|
||||
) -> str:
|
||||
"""Prompts user for Google Cloud region."""
|
||||
google_cloud_region = (
|
||||
google_cloud_region
|
||||
or os.environ.get("GOOGLE_CLOUD_LOCATION", None)
|
||||
or _get_gcp_region_from_gcloud()
|
||||
)
|
||||
|
||||
google_cloud_region = _prompt_str(
|
||||
"Enter Google Cloud region",
|
||||
default_value=google_cloud_region or "us-central1",
|
||||
)
|
||||
return google_cloud_region
|
||||
|
||||
|
||||
def _prompt_for_google_api_key(
|
||||
google_api_key: Optional[str],
|
||||
) -> str:
|
||||
"""Prompts user for Google API key."""
|
||||
google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None)
|
||||
|
||||
google_api_key = _prompt_str(
|
||||
"Enter Google API key",
|
||||
prior_msg=_GOOGLE_API_MSG,
|
||||
default_value=google_api_key,
|
||||
)
|
||||
return google_api_key
|
||||
|
||||
|
||||
def _generate_files(
|
||||
agent_folder: str,
|
||||
*,
|
||||
google_api_key: Optional[str] = None,
|
||||
google_cloud_project: Optional[str] = None,
|
||||
google_cloud_region: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
"""Generates a folder name for the agent."""
|
||||
os.makedirs(agent_folder, exist_ok=True)
|
||||
|
||||
dotenv_file_path = os.path.join(agent_folder, ".env")
|
||||
init_file_path = os.path.join(agent_folder, "__init__.py")
|
||||
agent_file_path = os.path.join(agent_folder, "agent.py")
|
||||
|
||||
with open(dotenv_file_path, "w", encoding="utf-8") as f:
|
||||
lines = []
|
||||
if google_api_key:
|
||||
lines.append("GOOGLE_GENAI_USE_VERTEXAI=0")
|
||||
elif google_cloud_project and google_cloud_region:
|
||||
lines.append("GOOGLE_GENAI_USE_VERTEXAI=1")
|
||||
if google_api_key:
|
||||
lines.append(f"GOOGLE_API_KEY={google_api_key}")
|
||||
if google_cloud_project:
|
||||
lines.append(f"GOOGLE_CLOUD_PROJECT={google_cloud_project}")
|
||||
if google_cloud_region:
|
||||
lines.append(f"GOOGLE_CLOUD_LOCATION={google_cloud_region}")
|
||||
f.write("\n".join(lines))
|
||||
|
||||
with open(init_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(_INIT_PY_TEMPLATE)
|
||||
|
||||
with open(agent_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(_AGENT_PY_TEMPLATE.format(model_name=model))
|
||||
|
||||
click.secho(
|
||||
_SUCCESS_MSG.format(agent_folder=agent_folder),
|
||||
fg="green",
|
||||
)
|
||||
|
||||
|
||||
def _prompt_for_model() -> str:
|
||||
model_choice = click.prompt(
|
||||
"""\
|
||||
Choose a model for the root agent:
|
||||
1. gemini-2.0-flash-001
|
||||
2. Other models (fill later)
|
||||
Choose model""",
|
||||
type=click.Choice(["1", "2"]),
|
||||
)
|
||||
if model_choice == "1":
|
||||
return "gemini-2.0-flash-001"
|
||||
else:
|
||||
click.secho(_OTHER_MODEL_MSG, fg="green")
|
||||
return "<FILL_IN_MODEL>"
|
||||
|
||||
|
||||
def _prompt_to_choose_backend(
|
||||
google_api_key: Optional[str],
|
||||
google_cloud_project: Optional[str],
|
||||
google_cloud_region: Optional[str],
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""Prompts user to choose backend.
|
||||
|
||||
Returns:
|
||||
A tuple of (google_api_key, google_cloud_project, google_cloud_region).
|
||||
"""
|
||||
backend_choice = click.prompt(
|
||||
"1. Google AI\n2. Vertex AI\nChoose a backend",
|
||||
type=click.Choice(["1", "2"]),
|
||||
)
|
||||
if backend_choice == "1":
|
||||
google_api_key = _prompt_for_google_api_key(google_api_key)
|
||||
elif backend_choice == "2":
|
||||
click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green")
|
||||
google_cloud_project = _prompt_for_google_cloud(google_cloud_project)
|
||||
google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region)
|
||||
return google_api_key, google_cloud_project, google_cloud_region
|
||||
|
||||
|
||||
def run_cmd(
|
||||
agent_name: str,
|
||||
*,
|
||||
model: Optional[str],
|
||||
google_api_key: Optional[str],
|
||||
google_cloud_project: Optional[str],
|
||||
google_cloud_region: Optional[str],
|
||||
):
|
||||
"""Runs `adk create` command to create agent template.
|
||||
|
||||
Args:
|
||||
agent_name: str, The name of the agent.
|
||||
google_api_key: Optional[str], The Google API key for using Google AI as
|
||||
backend.
|
||||
google_cloud_project: Optional[str], The Google Cloud project for using
|
||||
VertexAI as backend.
|
||||
google_cloud_region: Optional[str], The Google Cloud region for using
|
||||
VertexAI as backend.
|
||||
"""
|
||||
agent_folder = os.path.join(os.getcwd(), agent_name)
|
||||
# check folder doesn't exist or it's empty. Otherwise, throw
|
||||
if os.path.exists(agent_folder) and os.listdir(agent_folder):
|
||||
# Prompt user whether to override existing files using click
|
||||
if not click.confirm(
|
||||
f"Non-empty folder already exist: '{agent_folder}'\n"
|
||||
"Override existing content?",
|
||||
default=False,
|
||||
):
|
||||
raise click.Abort()
|
||||
|
||||
if not model:
|
||||
model = _prompt_for_model()
|
||||
|
||||
if not google_api_key and not (google_cloud_project and google_cloud_region):
|
||||
if model.startswith("gemini"):
|
||||
google_api_key, google_cloud_project, google_cloud_region = (
|
||||
_prompt_to_choose_backend(
|
||||
google_api_key, google_cloud_project, google_cloud_region
|
||||
)
|
||||
)
|
||||
|
||||
_generate_files(
|
||||
agent_folder,
|
||||
google_api_key=google_api_key,
|
||||
google_cloud_project=google_cloud_project,
|
||||
google_cloud_region=google_cloud_region,
|
||||
model=model,
|
||||
)
|
||||
@@ -82,8 +82,9 @@ def to_cloud_run(
|
||||
app_name: str,
|
||||
temp_folder: str,
|
||||
port: int,
|
||||
with_cloud_trace: bool,
|
||||
trace_to_cloud: bool,
|
||||
with_ui: bool,
|
||||
verbosity: str,
|
||||
):
|
||||
"""Deploys an agent to Google Cloud Run.
|
||||
|
||||
@@ -108,8 +109,9 @@ def to_cloud_run(
|
||||
app_name: The name of the app, by default, it's basename of `agent_folder`.
|
||||
temp_folder: The temp folder for the generated Cloud Run source files.
|
||||
port: The port of the ADK api server.
|
||||
with_cloud_trace: Whether to enable Cloud Trace.
|
||||
trace_to_cloud: Whether to enable Cloud Trace.
|
||||
with_ui: Whether to deploy with UI.
|
||||
verbosity: The verbosity level of the CLI.
|
||||
"""
|
||||
app_name = app_name or os.path.basename(agent_folder)
|
||||
|
||||
@@ -142,7 +144,7 @@ def to_cloud_run(
|
||||
port=port,
|
||||
command='web' if with_ui else 'api_server',
|
||||
install_agent_deps=install_agent_deps,
|
||||
trace_to_cloud_option='--trace_to_cloud' if with_cloud_trace else '',
|
||||
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
|
||||
)
|
||||
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
||||
os.makedirs(temp_folder, exist_ok=True)
|
||||
@@ -169,6 +171,8 @@ def to_cloud_run(
|
||||
*region_options,
|
||||
'--port',
|
||||
str(port),
|
||||
'--verbosity',
|
||||
verbosity,
|
||||
'--labels',
|
||||
'created-by=adk',
|
||||
],
|
||||
|
||||
@@ -24,6 +24,7 @@ import click
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from . import cli_create
|
||||
from . import cli_deploy
|
||||
from .cli import run_cli
|
||||
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
@@ -42,10 +43,59 @@ def main():
|
||||
|
||||
@main.group()
|
||||
def deploy():
|
||||
"""Deploy Agent."""
|
||||
"""Deploys agent to hosted environments."""
|
||||
pass
|
||||
|
||||
|
||||
@main.command("create")
|
||||
@click.option(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Optional. The model used for the root agent.",
|
||||
)
|
||||
@click.option(
|
||||
"--api_key",
|
||||
type=str,
|
||||
help=(
|
||||
"Optional. The API Key needed to access the model, e.g. Google AI API"
|
||||
" Key."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--project",
|
||||
type=str,
|
||||
help="Optional. The Google Cloud Project for using VertexAI as backend.",
|
||||
)
|
||||
@click.option(
|
||||
"--region",
|
||||
type=str,
|
||||
help="Optional. The Google Cloud Region for using VertexAI as backend.",
|
||||
)
|
||||
@click.argument("app_name", type=str, required=True)
|
||||
def cli_create_cmd(
|
||||
app_name: str,
|
||||
model: Optional[str],
|
||||
api_key: Optional[str],
|
||||
project: Optional[str],
|
||||
region: Optional[str],
|
||||
):
|
||||
"""Creates a new app in the current folder with prepopulated agent template.
|
||||
|
||||
APP_NAME: required, the folder of the agent source code.
|
||||
|
||||
Example:
|
||||
|
||||
adk create path/to/my_app
|
||||
"""
|
||||
cli_create.run_cmd(
|
||||
app_name,
|
||||
model=model,
|
||||
google_api_key=api_key,
|
||||
google_cloud_project=project,
|
||||
google_cloud_region=region,
|
||||
)
|
||||
|
||||
|
||||
@main.command("run")
|
||||
@click.option(
|
||||
"--save_session",
|
||||
@@ -62,7 +112,7 @@ def deploy():
|
||||
),
|
||||
)
|
||||
def cli_run(agent: str, save_session: bool):
|
||||
"""Run an interactive CLI for a certain agent.
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
|
||||
@@ -150,7 +200,7 @@ def cli_eval(
|
||||
EvalMetric(metric_name=metric_name, threshold=threshold)
|
||||
)
|
||||
|
||||
print(f"Using evaluation criteria: {evaluation_criteria}")
|
||||
print(f"Using evaluation creiteria: {evaluation_criteria}")
|
||||
|
||||
root_agent = get_root_agent(agent_module_file_path)
|
||||
reset_func = try_get_reset_func(agent_module_file_path)
|
||||
@@ -244,7 +294,7 @@ def cli_eval(
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
default=os.getcwd(),
|
||||
default=os.getcwd,
|
||||
)
|
||||
def cli_web(
|
||||
agents_dir: str,
|
||||
@@ -255,7 +305,7 @@ def cli_web(
|
||||
port: int = 8000,
|
||||
trace_to_cloud: bool = False,
|
||||
):
|
||||
"""Start a FastAPI server with Web UI for agents.
|
||||
"""Starts a FastAPI server with Web UI for agents.
|
||||
|
||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||
agent, containing at least `__init__.py` and `agent.py` files.
|
||||
@@ -274,7 +324,7 @@ def cli_web(
|
||||
@asynccontextmanager
|
||||
async def _lifespan(app: FastAPI):
|
||||
click.secho(
|
||||
f"""\
|
||||
f"""
|
||||
+-----------------------------------------------------------------------------+
|
||||
| ADK Web Server started |
|
||||
| |
|
||||
@@ -285,7 +335,7 @@ def cli_web(
|
||||
)
|
||||
yield # Startup is done, now app is running
|
||||
click.secho(
|
||||
"""\
|
||||
"""
|
||||
+-----------------------------------------------------------------------------+
|
||||
| ADK Web Server shutting down... |
|
||||
+-----------------------------------------------------------------------------+
|
||||
@@ -378,7 +428,7 @@ def cli_api_server(
|
||||
port: int = 8000,
|
||||
trace_to_cloud: bool = False,
|
||||
):
|
||||
"""Start a FastAPI server for agents.
|
||||
"""Starts a FastAPI server for agents.
|
||||
|
||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||
agent, containing at least `__init__.py` and `agent.py` files.
|
||||
@@ -452,7 +502,7 @@ def cli_api_server(
|
||||
help="Optional. The port of the ADK API server (default: 8000).",
|
||||
)
|
||||
@click.option(
|
||||
"--with_cloud_trace",
|
||||
"--trace_to_cloud",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
@@ -483,6 +533,14 @@ def cli_api_server(
|
||||
" (default: a timestamped folder in the system temp directory)."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--verbosity",
|
||||
type=click.Choice(
|
||||
["debug", "info", "warning", "error", "critical"], case_sensitive=False
|
||||
),
|
||||
default="WARNING",
|
||||
help="Optional. Override the default verbosity level.",
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
@@ -497,8 +555,9 @@ def cli_deploy_cloud_run(
|
||||
app_name: str,
|
||||
temp_folder: str,
|
||||
port: int,
|
||||
with_cloud_trace: bool,
|
||||
trace_to_cloud: bool,
|
||||
with_ui: bool,
|
||||
verbosity: str,
|
||||
):
|
||||
"""Deploys an agent to Cloud Run.
|
||||
|
||||
@@ -517,8 +576,9 @@ def cli_deploy_cloud_run(
|
||||
app_name=app_name,
|
||||
temp_folder=temp_folder,
|
||||
port=port,
|
||||
with_cloud_trace=with_cloud_trace,
|
||||
trace_to_cloud=trace_to_cloud,
|
||||
with_ui=with_ui,
|
||||
verbosity=verbosity,
|
||||
)
|
||||
except Exception as e:
|
||||
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -28,6 +30,7 @@ from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from click import Tuple
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
@@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..events.event import Event
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.database_session_service import DatabaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
@@ -143,11 +147,8 @@ def get_fast_api_app(
|
||||
provider.add_span_processor(
|
||||
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
||||
)
|
||||
envs.load_dotenv()
|
||||
enable_cloud_tracing = trace_to_cloud or os.environ.get(
|
||||
"ADK_TRACE_TO_CLOUD", "0"
|
||||
).lower() in ["1", "true"]
|
||||
if enable_cloud_tracing:
|
||||
if trace_to_cloud:
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(project_id=project_id)
|
||||
@@ -161,8 +162,22 @@ def get_fast_api_app(
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
exit_stacks = []
|
||||
|
||||
@asynccontextmanager
|
||||
async def internal_lifespan(app: FastAPI):
|
||||
if lifespan:
|
||||
async with lifespan(app) as lifespan_context:
|
||||
yield
|
||||
|
||||
if exit_stacks:
|
||||
for stack in exit_stacks:
|
||||
await stack.aclose()
|
||||
else:
|
||||
yield
|
||||
|
||||
# Run the FastAPI server.
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(lifespan=internal_lifespan)
|
||||
|
||||
if allow_origins:
|
||||
app.add_middleware(
|
||||
@@ -181,6 +196,7 @@ def get_fast_api_app(
|
||||
|
||||
# Build the Artifact service
|
||||
artifact_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
||||
# Build the Session service
|
||||
agent_engine_id = ""
|
||||
@@ -358,7 +374,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def add_session_to_eval_set(
|
||||
async def add_session_to_eval_set(
|
||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||
):
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
@@ -393,7 +409,9 @@ def get_fast_api_app(
|
||||
test_data = evals.convert_session_to_eval_format(session)
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(_get_root_agent(app_name))
|
||||
initial_session_state = create_empty_state(
|
||||
await _get_root_agent_async(app_name)
|
||||
)
|
||||
|
||||
eval_set_data.append({
|
||||
"name": req.eval_id,
|
||||
@@ -430,7 +448,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def run_eval(
|
||||
async def run_eval(
|
||||
app_name: str, eval_set_id: str, req: RunEvalRequest
|
||||
) -> list[RunEvalResult]:
|
||||
from .cli_eval import run_evals
|
||||
@@ -447,7 +465,7 @@ def get_fast_api_app(
|
||||
logger.info(
|
||||
"Eval ids to run list is empty. We will all evals in the eval set."
|
||||
)
|
||||
root_agent = _get_root_agent(app_name)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
eval_results = list(
|
||||
run_evals(
|
||||
eval_set_to_evals,
|
||||
@@ -577,7 +595,7 @@ def get_fast_api_app(
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = _get_runner(req.app_name)
|
||||
runner = await _get_runner_async(req.app_name)
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
@@ -604,7 +622,7 @@ def get_fast_api_app(
|
||||
async def event_generator():
|
||||
try:
|
||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
runner = _get_runner(req.app_name)
|
||||
runner = await _get_runner_async(req.app_name)
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
@@ -630,7 +648,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def get_event_graph(
|
||||
async def get_event_graph(
|
||||
app_name: str, user_id: str, session_id: str, event_id: str
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
@@ -647,7 +665,7 @@ def get_fast_api_app(
|
||||
|
||||
function_calls = event.get_function_calls()
|
||||
function_responses = event.get_function_responses()
|
||||
root_agent = _get_root_agent(app_name)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
dot_graph = None
|
||||
if function_calls:
|
||||
function_call_highlights = []
|
||||
@@ -704,7 +722,7 @@ def get_fast_api_app(
|
||||
live_request_queue = LiveRequestQueue()
|
||||
|
||||
async def forward_events():
|
||||
runner = _get_runner(app_name)
|
||||
runner = await _get_runner_async(app_name)
|
||||
async for event in runner.run_live(
|
||||
session=session, live_request_queue=live_request_queue
|
||||
):
|
||||
@@ -742,26 +760,40 @@ def get_fast_api_app(
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
def _get_root_agent(app_name: str) -> Agent:
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
return root_agent_dict[app_name]
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
agent_module = importlib.import_module(app_name)
|
||||
root_agent: Agent = agent_module.agent.root_agent
|
||||
if getattr(agent_module.agent, "root_agent"):
|
||||
root_agent = agent_module.agent.root_agent
|
||||
else:
|
||||
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
||||
|
||||
# Handle an awaitable root agent and await for the actual agent.
|
||||
if inspect.isawaitable(root_agent):
|
||||
try:
|
||||
agent, exit_stack = await root_agent
|
||||
exit_stacks.append(exit_stack)
|
||||
root_agent = agent
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"error getting root agent, {e}") from e
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
return root_agent
|
||||
|
||||
def _get_runner(app_name: str) -> Runner:
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
"""Returns the runner for the given app."""
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
if app_name in runner_dict:
|
||||
return runner_dict[app_name]
|
||||
root_agent = _get_root_agent(app_name)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
runner = Runner(
|
||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
memory_service=memory_service,
|
||||
)
|
||||
runner_dict[app_name] = runner
|
||||
return runner
|
||||
|
||||
@@ -50,8 +50,5 @@ def load_dotenv_for_agent(
|
||||
agent_name,
|
||||
dotenv_file_path,
|
||||
)
|
||||
logger.info(
|
||||
'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
|
||||
)
|
||||
else:
|
||||
logger.info('No %s file found for %s', filename, agent_name)
|
||||
|
||||
Reference in New Issue
Block a user