No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member
2025-04-17 19:50:22 +00:00
committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions
+52 -20
View File
@@ -13,7 +13,9 @@
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
@@ -28,6 +30,7 @@ from typing import Literal
from typing import Optional
import click
from click import Tuple
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
@@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
@@ -143,11 +147,8 @@ def get_fast_api_app(
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
envs.load_dotenv()
enable_cloud_tracing = trace_to_cloud or os.environ.get(
"ADK_TRACE_TO_CLOUD", "0"
).lower() in ["1", "true"]
if enable_cloud_tracing:
if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
@@ -161,8 +162,22 @@ def get_fast_api_app(
trace.set_tracer_provider(provider)
exit_stacks = []
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
if lifespan:
async with lifespan(app) as lifespan_context:
yield
if exit_stacks:
for stack in exit_stacks:
await stack.aclose()
else:
yield
# Run the FastAPI server.
app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=internal_lifespan)
if allow_origins:
app.add_middleware(
@@ -181,6 +196,7 @@ def get_fast_api_app(
# Build the Artifact service
artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
# Build the Session service
agent_engine_id = ""
@@ -358,7 +374,7 @@ def get_fast_api_app(
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
response_model_exclude_none=True,
)
def add_session_to_eval_set(
async def add_session_to_eval_set(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
pattern = r"^[a-zA-Z0-9_]+$"
@@ -393,7 +409,9 @@ def get_fast_api_app(
test_data = evals.convert_session_to_eval_format(session)
# Populate the session with initial session state.
initial_session_state = create_empty_state(_get_root_agent(app_name))
initial_session_state = create_empty_state(
await _get_root_agent_async(app_name)
)
eval_set_data.append({
"name": req.eval_id,
@@ -430,7 +448,7 @@ def get_fast_api_app(
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
response_model_exclude_none=True,
)
def run_eval(
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
from .cli_eval import run_evals
@@ -447,7 +465,7 @@ def get_fast_api_app(
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
eval_results = list(
run_evals(
eval_set_to_evals,
@@ -577,7 +595,7 @@ def get_fast_api_app(
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = _get_runner(req.app_name)
runner = await _get_runner_async(req.app_name)
events = [
event
async for event in runner.run_async(
@@ -604,7 +622,7 @@ def get_fast_api_app(
async def event_generator():
try:
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
runner = _get_runner(req.app_name)
runner = await _get_runner_async(req.app_name)
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
@@ -630,7 +648,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
response_model_exclude_none=True,
)
def get_event_graph(
async def get_event_graph(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
@@ -647,7 +665,7 @@ def get_fast_api_app(
function_calls = event.get_function_calls()
function_responses = event.get_function_responses()
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
dot_graph = None
if function_calls:
function_call_highlights = []
@@ -704,7 +722,7 @@ def get_fast_api_app(
live_request_queue = LiveRequestQueue()
async def forward_events():
runner = _get_runner(app_name)
runner = await _get_runner_async(app_name)
async for event in runner.run_live(
session=session, live_request_queue=live_request_queue
):
@@ -742,26 +760,40 @@ def get_fast_api_app(
for task in pending:
task.cancel()
def _get_root_agent(app_name: str) -> Agent:
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
return root_agent_dict[app_name]
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
agent_module = importlib.import_module(app_name)
root_agent: Agent = agent_module.agent.root_agent
if getattr(agent_module.agent, "root_agent"):
root_agent = agent_module.agent.root_agent
else:
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
# Handle an awaitable root agent and await for the actual agent.
if inspect.isawaitable(root_agent):
try:
agent, exit_stack = await root_agent
exit_stacks.append(exit_stack)
root_agent = agent
except Exception as e:
raise RuntimeError(f"error getting root agent, {e}") from e
root_agent_dict[app_name] = root_agent
return root_agent
def _get_runner(app_name: str) -> Runner:
async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
)
runner_dict[app_name] = runner
return runner