mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-05 06:16:24 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -28,6 +30,7 @@ from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from click import Tuple
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
@@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..events.event import Event
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.database_session_service import DatabaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
@@ -143,11 +147,8 @@ def get_fast_api_app(
|
||||
provider.add_span_processor(
|
||||
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
||||
)
|
||||
envs.load_dotenv()
|
||||
enable_cloud_tracing = trace_to_cloud or os.environ.get(
|
||||
"ADK_TRACE_TO_CLOUD", "0"
|
||||
).lower() in ["1", "true"]
|
||||
if enable_cloud_tracing:
|
||||
if trace_to_cloud:
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(project_id=project_id)
|
||||
@@ -161,8 +162,22 @@ def get_fast_api_app(
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
exit_stacks = []
|
||||
|
||||
@asynccontextmanager
|
||||
async def internal_lifespan(app: FastAPI):
|
||||
if lifespan:
|
||||
async with lifespan(app) as lifespan_context:
|
||||
yield
|
||||
|
||||
if exit_stacks:
|
||||
for stack in exit_stacks:
|
||||
await stack.aclose()
|
||||
else:
|
||||
yield
|
||||
|
||||
# Run the FastAPI server.
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(lifespan=internal_lifespan)
|
||||
|
||||
if allow_origins:
|
||||
app.add_middleware(
|
||||
@@ -181,6 +196,7 @@ def get_fast_api_app(
|
||||
|
||||
# Build the Artifact service
|
||||
artifact_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
||||
# Build the Session service
|
||||
agent_engine_id = ""
|
||||
@@ -358,7 +374,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def add_session_to_eval_set(
|
||||
async def add_session_to_eval_set(
|
||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||
):
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
@@ -393,7 +409,9 @@ def get_fast_api_app(
|
||||
test_data = evals.convert_session_to_eval_format(session)
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(_get_root_agent(app_name))
|
||||
initial_session_state = create_empty_state(
|
||||
await _get_root_agent_async(app_name)
|
||||
)
|
||||
|
||||
eval_set_data.append({
|
||||
"name": req.eval_id,
|
||||
@@ -430,7 +448,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def run_eval(
|
||||
async def run_eval(
|
||||
app_name: str, eval_set_id: str, req: RunEvalRequest
|
||||
) -> list[RunEvalResult]:
|
||||
from .cli_eval import run_evals
|
||||
@@ -447,7 +465,7 @@ def get_fast_api_app(
|
||||
logger.info(
|
||||
"Eval ids to run list is empty. We will all evals in the eval set."
|
||||
)
|
||||
root_agent = _get_root_agent(app_name)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
eval_results = list(
|
||||
run_evals(
|
||||
eval_set_to_evals,
|
||||
@@ -577,7 +595,7 @@ def get_fast_api_app(
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = _get_runner(req.app_name)
|
||||
runner = await _get_runner_async(req.app_name)
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
@@ -604,7 +622,7 @@ def get_fast_api_app(
|
||||
async def event_generator():
|
||||
try:
|
||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
runner = _get_runner(req.app_name)
|
||||
runner = await _get_runner_async(req.app_name)
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
@@ -630,7 +648,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def get_event_graph(
|
||||
async def get_event_graph(
|
||||
app_name: str, user_id: str, session_id: str, event_id: str
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
@@ -647,7 +665,7 @@ def get_fast_api_app(
|
||||
|
||||
function_calls = event.get_function_calls()
|
||||
function_responses = event.get_function_responses()
|
||||
root_agent = _get_root_agent(app_name)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
dot_graph = None
|
||||
if function_calls:
|
||||
function_call_highlights = []
|
||||
@@ -704,7 +722,7 @@ def get_fast_api_app(
|
||||
live_request_queue = LiveRequestQueue()
|
||||
|
||||
async def forward_events():
|
||||
runner = _get_runner(app_name)
|
||||
runner = await _get_runner_async(app_name)
|
||||
async for event in runner.run_live(
|
||||
session=session, live_request_queue=live_request_queue
|
||||
):
|
||||
@@ -742,26 +760,40 @@ def get_fast_api_app(
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
def _get_root_agent(app_name: str) -> Agent:
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
return root_agent_dict[app_name]
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
agent_module = importlib.import_module(app_name)
|
||||
root_agent: Agent = agent_module.agent.root_agent
|
||||
if getattr(agent_module.agent, "root_agent"):
|
||||
root_agent = agent_module.agent.root_agent
|
||||
else:
|
||||
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
||||
|
||||
# Handle an awaitable root agent and await for the actual agent.
|
||||
if inspect.isawaitable(root_agent):
|
||||
try:
|
||||
agent, exit_stack = await root_agent
|
||||
exit_stacks.append(exit_stack)
|
||||
root_agent = agent
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"error getting root agent, {e}") from e
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
return root_agent
|
||||
|
||||
def _get_runner(app_name: str) -> Runner:
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
"""Returns the runner for the given app."""
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
if app_name in runner_dict:
|
||||
return runner_dict[app_name]
|
||||
root_agent = _get_root_agent(app_name)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
runner = Runner(
|
||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
memory_service=memory_service,
|
||||
)
|
||||
runner_dict[app_name] = runner
|
||||
return runner
|
||||
|
||||
Reference in New Issue
Block a user