mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-05 06:16:24 -06:00
Changes for 0.1.0 release
This commit is contained in:
@@ -31,7 +31,6 @@ import click
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
@@ -48,6 +47,7 @@ from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from starlette.types import Lifespan
|
||||
|
||||
from ..agents import RunConfig
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
@@ -83,7 +83,11 @@ class ApiServerSpanExporter(export.SpanExporter):
|
||||
self, spans: typing.Sequence[ReadableSpan]
|
||||
) -> export.SpanExportResult:
|
||||
for span in spans:
|
||||
if span.name == "call_llm" or span.name == "send_data":
|
||||
if (
|
||||
span.name == "call_llm"
|
||||
or span.name == "send_data"
|
||||
or span.name.startswith("tool_response")
|
||||
):
|
||||
attributes = dict(span.attributes)
|
||||
attributes["trace_id"] = span.get_span_context().trace_id
|
||||
attributes["span_id"] = span.get_span_context().span_id
|
||||
@@ -128,6 +132,8 @@ def get_fast_api_app(
|
||||
session_db_url: str = "",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
web: bool,
|
||||
trace_to_cloud: bool = False,
|
||||
lifespan: Optional[Lifespan[FastAPI]] = None,
|
||||
) -> FastAPI:
|
||||
# InMemory tracing dict.
|
||||
trace_dict: dict[str, Any] = {}
|
||||
@@ -137,18 +143,26 @@ def get_fast_api_app(
|
||||
provider.add_span_processor(
|
||||
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
||||
)
|
||||
if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(
|
||||
project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
|
||||
)
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
envs.load_dotenv()
|
||||
enable_cloud_tracing = trace_to_cloud or os.environ.get(
|
||||
"ADK_TRACE_TO_CLOUD", "0"
|
||||
).lower() in ["1", "true"]
|
||||
if enable_cloud_tracing:
|
||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(project_id=project_id)
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
else:
|
||||
logging.warning(
|
||||
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
|
||||
" not be enabled."
|
||||
)
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
# Run the FastAPI server.
|
||||
app = FastAPI()
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
if allow_origins:
|
||||
app.add_middleware(
|
||||
@@ -478,6 +492,7 @@ def get_fast_api_app(
|
||||
artifact_name: str,
|
||||
version: Optional[int] = Query(None),
|
||||
) -> Optional[types.Part]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact = artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -500,6 +515,7 @@ def get_fast_api_app(
|
||||
artifact_name: str,
|
||||
version_id: int,
|
||||
) -> Optional[types.Part]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact = artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -518,6 +534,7 @@ def get_fast_api_app(
|
||||
def list_artifact_names(
|
||||
app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return artifact_service.list_artifact_keys(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
@@ -529,6 +546,7 @@ def get_fast_api_app(
|
||||
def list_artifact_versions(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
) -> list[int]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return artifact_service.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -542,6 +560,7 @@ def get_fast_api_app(
|
||||
def delete_artifact(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
):
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact_service.delete_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
|
||||
Reference in New Issue
Block a user