Changes for 0.1.0 release

This commit is contained in:
hangfei
2025-04-09 04:24:34 +00:00
parent 9827820143
commit 363e10619a
25 changed files with 553 additions and 99 deletions
+29 -10
View File
@@ -31,7 +31,6 @@ import click
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
from fastapi import Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
@@ -48,6 +47,7 @@ from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import BaseModel
from pydantic import ValidationError
from starlette.types import Lifespan
from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
@@ -83,7 +83,11 @@ class ApiServerSpanExporter(export.SpanExporter):
self, spans: typing.Sequence[ReadableSpan]
) -> export.SpanExportResult:
for span in spans:
if span.name == "call_llm" or span.name == "send_data":
if (
span.name == "call_llm"
or span.name == "send_data"
or span.name.startswith("tool_response")
):
attributes = dict(span.attributes)
attributes["trace_id"] = span.get_span_context().trace_id
attributes["span_id"] = span.get_span_context().span_id
@@ -128,6 +132,8 @@ def get_fast_api_app(
session_db_url: str = "",
allow_origins: Optional[list[str]] = None,
web: bool,
trace_to_cloud: bool = False,
lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
# InMemory tracing dict.
trace_dict: dict[str, Any] = {}
@@ -137,18 +143,26 @@ def get_fast_api_app(
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(
project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
)
)
provider.add_span_processor(processor)
envs.load_dotenv()
enable_cloud_tracing = trace_to_cloud or os.environ.get(
"ADK_TRACE_TO_CLOUD", "0"
).lower() in ["1", "true"]
if enable_cloud_tracing:
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
)
provider.add_span_processor(processor)
else:
logging.warning(
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
" not be enabled."
)
trace.set_tracer_provider(provider)
# Run the FastAPI server.
app = FastAPI()
app = FastAPI(lifespan=lifespan)
if allow_origins:
app.add_middleware(
@@ -478,6 +492,7 @@ def get_fast_api_app(
artifact_name: str,
version: Optional[int] = Query(None),
) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name
artifact = artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
@@ -500,6 +515,7 @@ def get_fast_api_app(
artifact_name: str,
version_id: int,
) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name
artifact = artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
@@ -518,6 +534,7 @@ def get_fast_api_app(
def list_artifact_names(
app_name: str, user_id: str, session_id: str
) -> list[str]:
app_name = agent_engine_id if agent_engine_id else app_name
return artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -529,6 +546,7 @@ def get_fast_api_app(
def list_artifact_versions(
app_name: str, user_id: str, session_id: str, artifact_name: str
) -> list[int]:
app_name = agent_engine_id if agent_engine_id else app_name
return artifact_service.list_versions(
app_name=app_name,
user_id=user_id,
@@ -542,6 +560,7 @@ def get_fast_api_app(
def delete_artifact(
app_name: str, user_id: str, session_id: str, artifact_name: str
):
app_name = agent_engine_id if agent_engine_id else app_name
artifact_service.delete_artifact(
app_name=app_name,
user_id=user_id,