# Mirror of https://github.com/EvolutionAPI/adk-python.git
# Synced 2025-07-14 01:41:25 -06:00 · 834 lines · 27 KiB · Python
# Copyright 2025 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
import importlib
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
import re
|
|
import sys
|
|
import traceback
|
|
import typing
|
|
from typing import Any
|
|
from typing import List
|
|
from typing import Literal
|
|
from typing import Optional
|
|
from typing import Union
|
|
|
|
import click
|
|
from fastapi import FastAPI
|
|
from fastapi import HTTPException
|
|
from fastapi import Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.responses import RedirectResponse
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.websockets import WebSocket
|
|
from fastapi.websockets import WebSocketDisconnect
|
|
from google.genai import types
|
|
import graphviz
|
|
from opentelemetry import trace
|
|
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
|
|
from opentelemetry.sdk.trace import export
|
|
from opentelemetry.sdk.trace import ReadableSpan
|
|
from opentelemetry.sdk.trace import TracerProvider
|
|
from pydantic import BaseModel
|
|
from pydantic import ValidationError
|
|
from starlette.types import Lifespan
|
|
|
|
from ..agents import RunConfig
|
|
from ..agents.base_agent import BaseAgent
|
|
from ..agents.live_request_queue import LiveRequest
|
|
from ..agents.live_request_queue import LiveRequestQueue
|
|
from ..agents.llm_agent import Agent
|
|
from ..agents.llm_agent import LlmAgent
|
|
from ..agents.run_config import StreamingMode
|
|
from ..artifacts import InMemoryArtifactService
|
|
from ..events.event import Event
|
|
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
|
from ..runners import Runner
|
|
from ..sessions.database_session_service import DatabaseSessionService
|
|
from ..sessions.in_memory_session_service import InMemorySessionService
|
|
from ..sessions.session import Session
|
|
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
|
from ..tools.base_toolset import BaseToolset
|
|
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
|
from .cli_eval import EvalMetric
|
|
from .cli_eval import EvalMetricResult
|
|
from .cli_eval import EvalStatus
|
|
from .utils import create_empty_state
|
|
from .utils import envs
|
|
from .utils import evals
|
|
|
|
# Module-level logger shared by all endpoints defined in this module.
logger = logging.getLogger(__name__)

# Filename suffix of eval set files stored inside an app's directory; eval set
# ids are derived by stripping this suffix from matching file names.
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
|
|
|
|
|
class ApiServerSpanExporter(export.SpanExporter):
  """Span exporter that mirrors selected span attributes into a plain dict.

  Only spans named ``call_llm`` or ``send_data``, or whose name starts with
  ``tool_response``, are recorded. For each of those, a copy of the span's
  attributes is extended with the span's ``trace_id`` and ``span_id`` and
  stored in ``trace_dict`` under the value of its
  ``gcp.vertex.agent.event_id`` attribute. Spans without a truthy event id
  are dropped.
  """

  def __init__(self, trace_dict):
    # Caller-owned mapping that export() writes into.
    self.trace_dict = trace_dict

  def export(
      self, spans: typing.Sequence[ReadableSpan]
  ) -> export.SpanExportResult:
    """Records every matching span in ``self.trace_dict``; always succeeds."""
    for span in spans:
      span_name = span.name
      wanted = span_name in ("call_llm", "send_data") or span_name.startswith(
          "tool_response"
      )
      if not wanted:
        continue

      recorded = dict(span.attributes)
      span_context = span.get_span_context()
      recorded["trace_id"] = span_context.trace_id
      recorded["span_id"] = span_context.span_id

      event_id = recorded.get("gcp.vertex.agent.event_id", None)
      if event_id:
        self.trace_dict[event_id] = recorded
    return export.SpanExportResult.SUCCESS

  def force_flush(self, timeout_millis: int = 30000) -> bool:
    """Nothing is buffered by this exporter, so flushing always succeeds."""
    return True
|
|
|
|
|
|
class AgentRunRequest(BaseModel):
  # Request body accepted by the /run and /run_sse endpoints.
  # Name of the agent app; also used (unless an agent engine id is configured)
  # as the app name under which the session is looked up.
  app_name: str
  user_id: str
  # Must identify an existing session, otherwise the endpoints return 404.
  session_id: str
  # Passed as `new_message` to Runner.run_async.
  new_message: types.Content
  # Only read by /run_sse: True selects StreamingMode.SSE, False selects
  # StreamingMode.NONE.
  streaming: bool = False
|
|
|
|
|
|
class AddSessionToEvalSetRequest(BaseModel):
  # Request body for POST /apps/{app_name}/eval_sets/{eval_set_id}/add_session.
  # Name of the new eval entry; must match ^[a-zA-Z0-9_]+$ and be unique
  # within the eval set.
  eval_id: str
  # Existing session whose events are converted into eval data.
  session_id: str
  user_id: str
|
|
|
|
|
|
class RunEvalRequest(BaseModel):
  # Request body for POST /apps/{app_name}/eval_sets/{eval_set_id}/run_eval.
  eval_ids: list[str]  # if empty, then all evals in the eval set are run.
  # Metrics forwarded unchanged to cli_eval.run_evals.
  eval_metrics: list[EvalMetric]
|
|
|
|
|
|
class RunEvalResult(BaseModel):
  # One element of the list returned by the run_eval endpoint; built from
  # each result yielded by cli_eval.run_evals.
  eval_set_id: str
  eval_id: str
  final_eval_status: EvalStatus
  # (metric, result) pairs computed for this eval.
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
  # Session id reported by run_evals for this eval run.
  session_id: str
|
|
|
|
|
|
def get_fast_api_app(
    *,
    agent_dir: str,
    session_db_url: str = "",
    allow_origins: Optional[list[str]] = None,
    web: bool,
    trace_to_cloud: bool = False,
    lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
  """Builds the FastAPI app that serves the agent apps found in ``agent_dir``.

  Side effect: installs a new global OpenTelemetry tracer provider.

  Args:
    agent_dir: Directory whose sub-directories are agent app packages. It is
      appended to ``sys.path`` so each app can be imported by its name.
    session_db_url: Selects the session backend. Empty -> in-memory;
      ``agentengine://<id>`` -> Vertex AI sessions stored under ``<id>``
      (reads ``GOOGLE_CLOUD_PROJECT`` / ``GOOGLE_CLOUD_LOCATION``); anything
      else -> ``DatabaseSessionService`` with this URL.
    allow_origins: When non-empty, CORS is enabled for these origins.
    web: When True, also serve the bundled dev UI from ``./browser``.
    trace_to_cloud: When True, additionally export spans to Cloud Trace if
      ``GOOGLE_CLOUD_PROJECT`` is set (otherwise only a warning is logged).
    lifespan: Optional user lifespan that the app's own lifespan wraps.

  Returns:
    The configured FastAPI application.

  Raises:
    click.ClickException: If ``session_db_url`` is ``agentengine://`` with an
      empty engine id.
    KeyError: If an ``agentengine://`` URL is used and the Google Cloud
      project/location environment variables are missing.
  """
  # In-memory store of exported span attributes keyed by event id; written by
  # ApiServerSpanExporter and served by /debug/trace/{event_id}.
  trace_dict: dict[str, Any] = {}

  # Set up tracing in the FastAPI server.
  provider = TracerProvider()
  provider.add_span_processor(
      export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
  )
  if trace_to_cloud:
    envs.load_dotenv_for_agent("", agent_dir)
    if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
      processor = export.BatchSpanProcessor(
          CloudTraceSpanExporter(project_id=project_id)
      )
      provider.add_span_processor(processor)
    else:
      logger.warning(
          "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
          " not be enabled."
      )

  trace.set_tracer_provider(provider)

  # Resources acquired while loading root agents; released at app shutdown.
  exit_stacks = []
  toolsets_to_close: set[BaseToolset] = set()

  async def _release_agent_resources():
    """Closes exit stacks of awaited root agents and all collected toolsets."""
    for stack in exit_stacks:
      await stack.aclose()
    for toolset in toolsets_to_close:
      await toolset.close()

  @asynccontextmanager
  async def internal_lifespan(app: FastAPI):
    # Cleanup must run on shutdown whether or not a user lifespan was given;
    # it previously only ran in the `lifespan` branch and leaked otherwise.
    # With a user lifespan, cleanup still happens before that lifespan exits.
    if lifespan:
      async with lifespan(app):
        try:
          yield
        finally:
          await _release_agent_resources()
    else:
      try:
        yield
      finally:
        await _release_agent_resources()

  # Run the FastAPI server.
  app = FastAPI(lifespan=internal_lifespan)

  if allow_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

  # Make every app package under agent_dir importable by its name.
  if agent_dir not in sys.path:
    sys.path.append(agent_dir)

  # Per-app caches, filled lazily by _get_runner_async/_get_root_agent_async.
  runner_dict = {}
  root_agent_dict = {}

  # Build the Artifact service
  artifact_service = InMemoryArtifactService()
  memory_service = InMemoryMemoryService()

  # Build the Session service
  agent_engine_id = ""
  if session_db_url:
    if session_db_url.startswith("agentengine://"):
      # Create vertex session service
      agent_engine_id = session_db_url.split("://")[1]
      if not agent_engine_id:
        raise click.ClickException("Agent engine id can not be empty.")
      envs.load_dotenv_for_agent("", agent_dir)
      session_service = VertexAiSessionService(
          os.environ["GOOGLE_CLOUD_PROJECT"],
          os.environ["GOOGLE_CLOUD_LOCATION"],
      )
    else:
      session_service = DatabaseSessionService(db_url=session_db_url)
  else:
    session_service = InMemorySessionService()

  def _session_app_name(app_name: str) -> str:
    """Returns the app name under which sessions for `app_name` are stored.

    When an agent engine id is configured (``agentengine://`` URL), sessions
    live under that id instead of the per-app name.
    """
    return agent_engine_id if agent_engine_id else app_name

  # Allowed format for eval set ids and eval ids.
  id_pattern = r"^[a-zA-Z0-9_]+$"

  @app.get("/list-apps")
  def list_apps() -> list[str]:
    base_path = Path.cwd() / agent_dir
    if not base_path.exists():
      raise HTTPException(status_code=404, detail="Path not found")
    if not base_path.is_dir():
      raise HTTPException(status_code=400, detail="Not a directory")
    agent_names = [
        x
        for x in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, x))
        and not x.startswith(".")
        and x != "__pycache__"
    ]
    agent_names.sort()
    return agent_names

  @app.get("/debug/trace/{event_id}")
  def get_trace_dict(event_id: str) -> Any:
    event_dict = trace_dict.get(event_id, None)
    if event_dict is None:
      raise HTTPException(status_code=404, detail="Trace not found")
    return event_dict

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
      response_model_exclude_none=True,
  )
  def get_session(app_name: str, user_id: str, session_id: str) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = _session_app_name(app_name)
    session = session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")
    return session

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions",
      response_model_exclude_none=True,
  )
  def list_sessions(app_name: str, user_id: str) -> list[Session]:
    # Connect to managed session if agent_engine_id is set.
    app_name = _session_app_name(app_name)
    return [
        session
        for session in session_service.list_sessions(
            app_name=app_name, user_id=user_id
        ).sessions
        # Remove sessions that were generated as a part of Eval.
        if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
    ]

  @app.post(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
      response_model_exclude_none=True,
  )
  def create_session_with_id(
      app_name: str,
      user_id: str,
      session_id: str,
      state: Optional[dict[str, Any]] = None,
  ) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = _session_app_name(app_name)
    if (
        session_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session_id
        )
        is not None
    ):
      logger.warning("Session already exists: %s", session_id)
      raise HTTPException(
          status_code=400, detail=f"Session already exists: {session_id}"
      )

    session = session_service.create_session(
        app_name=app_name, user_id=user_id, state=state, session_id=session_id
    )
    # Log only once creation has actually succeeded.
    logger.info("New session created: %s", session_id)
    return session

  @app.post(
      "/apps/{app_name}/users/{user_id}/sessions",
      response_model_exclude_none=True,
  )
  def create_session(
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
  ) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = _session_app_name(app_name)
    session = session_service.create_session(
        app_name=app_name, user_id=user_id, state=state
    )
    logger.info("New session created")
    return session

  def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
    """Returns `<agent_dir>/<app_name>/<eval_set_id>.evalset.json`."""
    return os.path.join(
        agent_dir,
        app_name,
        eval_set_id + _EVAL_SET_FILE_EXTENSION,
    )

  def _load_eval_set_data(eval_set_file_path: str, eval_set_id: str) -> Any:
    """Loads the JSON list stored in an eval set file.

    Raises:
      HTTPException: 404 if the eval set file does not exist (previously an
        unhandled FileNotFoundError, i.e. a 500).
    """
    try:
      with open(eval_set_file_path, "r") as file:
        return json.load(file)  # Load JSON into a list
    except FileNotFoundError as e:
      raise HTTPException(
          status_code=404, detail=f"Eval set `{eval_set_id}` not found."
      ) from e

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}",
      response_model_exclude_none=True,
  )
  def create_eval_set(
      app_name: str,
      eval_set_id: str,
  ):
    """Creates an eval set, given the id."""
    if not re.fullmatch(id_pattern, eval_set_id):
      raise HTTPException(
          status_code=400,
          detail=(
              f"Invalid eval set id. Eval set id should have the `{id_pattern}`"
              " format"
          ),
      )
    # Define the file path
    new_eval_set_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )

    logger.info("Creating eval set file `%s`", new_eval_set_path)

    # An already existing eval set file is left untouched.
    if not os.path.exists(new_eval_set_path):
      # Write the JSON string to the file
      logger.info("Eval set file doesn't exist, we will create a new one.")
      with open(new_eval_set_path, "w") as f:
        empty_content = json.dumps([], indent=2)
        f.write(empty_content)

  @app.get(
      "/apps/{app_name}/eval_sets",
      response_model_exclude_none=True,
  )
  def list_eval_sets(app_name: str) -> list[str]:
    """Lists all eval sets for the given app."""
    eval_set_file_path = os.path.join(agent_dir, app_name)
    # A missing app directory is a client error, not a 500.
    if not os.path.isdir(eval_set_file_path):
      raise HTTPException(status_code=404, detail="App not found")
    return sorted(
        file.removesuffix(_EVAL_SET_FILE_EXTENSION)
        for file in os.listdir(eval_set_file_path)
        if file.endswith(_EVAL_SET_FILE_EXTENSION)
    )

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
      response_model_exclude_none=True,
  )
  async def add_session_to_eval_set(
      app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
  ):
    """Appends an existing session, converted to eval data, to an eval set."""
    if not re.fullmatch(id_pattern, req.eval_id):
      raise HTTPException(
          status_code=400,
          detail=(
              "Invalid eval id. Eval id should have the"
              f" `{id_pattern}` format"
          ),
      )

    # Get the session. Like every other session endpoint, look it up under the
    # agent engine id when one is configured.
    session = session_service.get_session(
        app_name=_session_app_name(app_name),
        user_id=req.user_id,
        session_id=req.session_id,
    )
    # Previously an `assert`: that surfaced as a 500 and vanishes under -O.
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")

    # Load the eval set file data
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_data = _load_eval_set_data(eval_set_file_path, eval_set_id)

    if any(x["name"] == req.eval_id for x in eval_set_data):
      raise HTTPException(
          status_code=400,
          detail=(
              f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
              " eval set."
          ),
      )

    # Convert the session data to evaluation format
    test_data = evals.convert_session_to_eval_format(session)

    # Populate the session with initial session state.
    initial_session_state = create_empty_state(
        await _get_root_agent_async(app_name)
    )

    eval_set_data.append({
        "name": req.eval_id,
        "data": test_data,
        "initial_session": {
            "state": initial_session_state,
            "app_name": app_name,
            "user_id": req.user_id,
        },
    })
    # Serialize the test data to JSON and write to the eval set file.
    with open(eval_set_file_path, "w") as f:
      f.write(json.dumps(eval_set_data, indent=2))

  @app.get(
      "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
      response_model_exclude_none=True,
  )
  def list_evals_in_eval_set(
      app_name: str,
      eval_set_id: str,
  ) -> list[str]:
    """Lists all evals in an eval set."""
    # Load the eval set file data
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_data = _load_eval_set_data(eval_set_file_path, eval_set_id)

    return sorted([x["name"] for x in eval_set_data])

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
      response_model_exclude_none=True,
  )
  async def run_eval(
      app_name: str, eval_set_id: str, req: RunEvalRequest
  ) -> list[RunEvalResult]:
    """Runs an eval given the details in the eval request."""
    # Imported lazily (as before); the docstring must precede it to count.
    from .cli_eval import run_evals

    # Create a mapping from eval set file to all the evals that needed to be
    # run.
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_to_evals = {eval_set_file_path: req.eval_ids}

    if not req.eval_ids:
      logger.info(
          "Eval ids to run list is empty. We will run all evals in the eval"
          " set."
      )
    root_agent = await _get_root_agent_async(app_name)
    # RunEvalResult has no `app_name` field; the previous extra kwarg was
    # silently ignored by pydantic and has been dropped.
    return [
        RunEvalResult(
            eval_set_id=eval_set_id,
            eval_id=eval_result.eval_id,
            final_eval_status=eval_result.final_eval_status,
            eval_metric_results=eval_result.eval_metric_results,
            session_id=eval_result.session_id,
        )
        async for eval_result in run_evals(
            eval_set_to_evals,
            root_agent,
            getattr(root_agent, "reset_data", None),
            req.eval_metrics,
            session_service=session_service,
            artifact_service=artifact_service,
        )
    ]

  @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
  def delete_session(app_name: str, user_id: str, session_id: str):
    # Connect to managed session if agent_engine_id is set.
    app_name = _session_app_name(app_name)
    session_service.delete_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
      response_model_exclude_none=True,
  )
  async def load_artifact(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version: Optional[int] = Query(None),
  ) -> Optional[types.Part]:
    app_name = _session_app_name(app_name)
    artifact = await artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
      response_model_exclude_none=True,
  )
  async def load_artifact_version(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version_id: int,
  ) -> Optional[types.Part]:
    app_name = _session_app_name(app_name)
    artifact = await artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version_id,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
      response_model_exclude_none=True,
  )
  async def list_artifact_names(
      app_name: str, user_id: str, session_id: str
  ) -> list[str]:
    app_name = _session_app_name(app_name)
    return await artifact_service.list_artifact_keys(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
      response_model_exclude_none=True,
  )
  async def list_artifact_versions(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ) -> list[int]:
    app_name = _session_app_name(app_name)
    return await artifact_service.list_versions(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.delete(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
  )
  async def delete_artifact(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ):
    app_name = _session_app_name(app_name)
    await artifact_service.delete_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.post("/run", response_model_exclude_none=True)
  async def agent_run(req: AgentRunRequest) -> list[Event]:
    # Connect to managed session if agent_engine_id is set.
    app_id = _session_app_name(req.app_name)
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")
    runner = await _get_runner_async(req.app_name)
    events = [
        event
        async for event in runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=req.new_message,
        )
    ]
    logger.info("Generated %s events in agent run: %s", len(events), events)
    return events

  @app.post("/run_sse")
  async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
    # Connect to managed session if agent_engine_id is set.
    app_id = _session_app_name(req.app_name)
    # SSE endpoint
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")

    # Convert the events to properly formatted SSE
    async def event_generator():
      try:
        stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
        runner = await _get_runner_async(req.app_name)
        async for event in runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=req.new_message,
            run_config=RunConfig(streaming_mode=stream_mode),
        ):
          # Format as SSE data
          sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
          logger.info("Generated event in agent run streaming: %s", sse_event)
          yield f"data: {sse_event}\n\n"
      except Exception as e:
        logger.exception("Error in event_generator: %s", e)
        # JSON-encode the message: interpolating it raw produced invalid JSON
        # whenever it contained quotes, backslashes or newlines.
        error_payload = json.dumps({"error": str(e)})
        yield f"data: {error_payload}\n\n"

    # Returns a streaming response with the proper media type for SSE
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
      response_model_exclude_none=True,
  )
  async def get_event_graph(
      app_name: str, user_id: str, session_id: str, event_id: str
  ):
    # Connect to managed session if agent_engine_id is set.
    app_id = _session_app_name(app_name)
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    session_events = session.events if session else []
    event = next((x for x in session_events if x.id == event_id), None)
    if not event:
      return {}

    from . import agent_graph

    function_calls = event.get_function_calls()
    function_responses = event.get_function_responses()
    root_agent = await _get_root_agent_async(app_name)
    dot_graph = None
    # Highlight edges: author -> called tool, tool -> author for responses,
    # or just the author when the event has neither.
    if function_calls:
      function_call_highlights = []
      for function_call in function_calls:
        from_name = event.author
        to_name = function_call.name
        function_call_highlights.append((from_name, to_name))
        dot_graph = await agent_graph.get_agent_graph(
            root_agent, function_call_highlights
        )
    elif function_responses:
      function_responses_highlights = []
      for function_response in function_responses:
        from_name = function_response.name
        to_name = event.author
        function_responses_highlights.append((from_name, to_name))
        dot_graph = await agent_graph.get_agent_graph(
            root_agent, function_responses_highlights
        )
    else:
      from_name = event.author
      to_name = ""
      dot_graph = await agent_graph.get_agent_graph(
          root_agent, [(from_name, to_name)]
      )
    if dot_graph and isinstance(dot_graph, graphviz.Digraph):
      return {"dot_src": dot_graph.source}
    else:
      return {}

  @app.websocket("/run_live")
  async def agent_live_run(
      websocket: WebSocket,
      app_name: str,
      user_id: str,
      session_id: str,
      modalities: List[Literal["TEXT", "AUDIO"]] = Query(
          default=["TEXT", "AUDIO"]
      ),  # Only allows "TEXT" or "AUDIO"
  ) -> None:
    await websocket.accept()

    # Connect to managed session if agent_engine_id is set.
    app_id = _session_app_name(app_name)
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    if not session:
      # Accept first so that the client is aware of connection establishment,
      # then close with a specific code.
      await websocket.close(code=1002, reason="Session not found")
      return

    live_request_queue = LiveRequestQueue()

    async def forward_events():
      # Streams live events from the runner to the client as JSON text frames.
      runner = await _get_runner_async(app_name)
      async for event in runner.run_live(
          session=session, live_request_queue=live_request_queue
      ):
        await websocket.send_text(
            event.model_dump_json(exclude_none=True, by_alias=True)
        )

    async def process_messages():
      # Feeds client frames into the live queue; stops at the first frame
      # that fails LiveRequest validation (the error is only logged).
      try:
        while True:
          data = await websocket.receive_text()
          # Validate and send the received message to the live queue.
          live_request_queue.send(LiveRequest.model_validate_json(data))
      except ValidationError as ve:
        logger.error("Validation error in process_messages: %s", ve)

    # Run both tasks concurrently and cancel all if one fails.
    tasks = [
        asyncio.create_task(forward_events()),
        asyncio.create_task(process_messages()),
    ]
    done, pending = await asyncio.wait(
        tasks, return_when=asyncio.FIRST_EXCEPTION
    )
    try:
      # This will re-raise any exception from the completed tasks.
      for task in done:
        task.result()
    except WebSocketDisconnect:
      logger.info("Client disconnected during process_messages.")
    except Exception as e:
      logger.exception("Error during live websocket communication: %s", e)
      traceback.print_exc()
      WEBSOCKET_INTERNAL_ERROR_CODE = 1011
      WEBSOCKET_MAX_BYTES_FOR_REASON = 123
      await websocket.close(
          code=WEBSOCKET_INTERNAL_ERROR_CODE,
          reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
      )
    finally:
      for task in pending:
        task.cancel()

  def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
    """Collects toolsets used by `agent` and, recursively, its sub-agents."""
    toolsets = set()
    if isinstance(agent, LlmAgent):
      for tool_union in agent.tools:
        if isinstance(tool_union, BaseToolset):
          toolsets.add(tool_union)
    for sub_agent in agent.sub_agents:
      toolsets.update(_get_all_toolsets(sub_agent))
    return toolsets

  async def _get_root_agent_async(app_name: str) -> Agent:
    """Returns the root agent for the given app, importing it on first use.

    The app package must expose ``agent.root_agent``. If that value is
    awaitable, it is awaited and must produce ``(agent, exit_stack)``; the
    exit stack is kept so it can be closed at shutdown.

    Raises:
      ValueError: If the package exposes no truthy ``agent.root_agent``.
      RuntimeError: If awaiting an awaitable root agent fails.
    """
    if app_name in root_agent_dict:
      return root_agent_dict[app_name]
    agent_module = importlib.import_module(app_name)
    # getattr with defaults: without them a missing attribute raised a bare
    # AttributeError and the ValueError below was unreachable.
    agent_submodule = getattr(agent_module, "agent", None)
    root_agent = getattr(agent_submodule, "root_agent", None)
    if not root_agent:
      raise ValueError(f'Unable to find "root_agent" from {app_name}.')

    # Handle an awaitable root agent and await for the actual agent.
    if inspect.isawaitable(root_agent):
      try:
        agent, exit_stack = await root_agent
        exit_stacks.append(exit_stack)
        root_agent = agent
      except Exception as e:
        raise RuntimeError(f"error getting root agent, {e}") from e

    root_agent_dict[app_name] = root_agent
    toolsets_to_close.update(_get_all_toolsets(root_agent))
    return root_agent

  async def _get_runner_async(app_name: str) -> Runner:
    """Returns the (cached) runner for the given app."""
    # Loaded on every call, before the cache check, as before.
    envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
    if app_name in runner_dict:
      return runner_dict[app_name]
    root_agent = await _get_root_agent_async(app_name)
    runner = Runner(
        app_name=_session_app_name(app_name),
        agent=root_agent,
        artifact_service=artifact_service,
        session_service=session_service,
        memory_service=memory_service,
    )
    runner_dict[app_name] = runner
    return runner

  if web:
    BASE_DIR = Path(__file__).parent.resolve()
    ANGULAR_DIST_PATH = BASE_DIR / "browser"

    @app.get("/")
    async def redirect_to_dev_ui():
      return RedirectResponse("/dev-ui")

    @app.get("/dev-ui")
    async def dev_ui():
      return FileResponse(BASE_DIR / "browser/index.html")

    app.mount(
        "/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
    )

  return app
|