From 80813a75cf90412560631d42ecd8e797d0bb6251 Mon Sep 17 00:00:00 2001 From: Yifan Wang Date: Mon, 12 May 2025 17:56:20 -0700 Subject: [PATCH] Add debug trace endpoint in api server Details: - Add a in-memory SpanExporter to capture all trace information. - Add /debug/trace/session/{session_id} endpoint to retrieve traces from the in-memory exporter. - Add Session ID in Telemetry spans. PiperOrigin-RevId: 757984565 --- src/google/adk/cli/fast_api.py | 60 ++++++++++++++++++++++++++++++++-- src/google/adk/telemetry.py | 3 ++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index f898362..0404c83 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter from opentelemetry.sdk.trace import export from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from pydantic import alias_generators from pydantic import BaseModel from pydantic import ConfigDict @@ -112,6 +113,42 @@ class ApiServerSpanExporter(export.SpanExporter): return True +class InMemoryExporter(export.SpanExporter): + + def __init__(self, trace_dict): + super().__init__() + self._spans = [] + self.trace_dict = trace_dict + + def export( + self, spans: typing.Sequence[ReadableSpan] + ) -> export.SpanExportResult: + for span in spans: + trace_id = span.context.trace_id + if span.name == "call_llm": + attributes = dict(span.attributes) + session_id = attributes.get("gcp.vertex.agent.session_id", None) + if session_id: + if session_id not in self.trace_dict: + self.trace_dict[session_id] = [trace_id] + else: + self.trace_dict[session_id] += [trace_id] + self._spans.extend(spans) + return export.SpanExportResult.SUCCESS + + def get_finished_spans(self, session_id: str): + trace_ids = self.trace_dict.get(session_id, None) + if trace_ids is None or not trace_ids: + return [] + return [x for x in self._spans if x.context.trace_id in trace_ids] + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def clear(self): + self._spans.clear() + + class AgentRunRequest(BaseModel): app_name: str user_id: str @@ -152,12 +189,15 @@ def get_fast_api_app( ) -> FastAPI: # InMemory tracing dict. trace_dict: dict[str, Any] = {} + session_trace_dict: dict[str, Any] = {} # Set up tracing in the FastAPI server. provider = TracerProvider() provider.add_span_processor( export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)) ) + memory_exporter = InMemoryExporter(session_trace_dict) + provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter)) if trace_to_cloud: envs.load_dotenv_for_agent("", agent_dir) if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): @@ -254,6 +294,24 @@ def get_fast_api_app( raise HTTPException(status_code=404, detail="Trace not found") return event_dict + @app.get("/debug/trace/session/{session_id}") + def get_session_trace(session_id: str) -> Any: + spans = memory_exporter.get_finished_spans(session_id) + if not spans: + return [] + return [ + { + "name": s.name, + "span_id": s.context.span_id, + "trace_id": s.context.trace_id, + "start_time": s.start_time, + "end_time": s.end_time, + "attributes": dict(s.attributes), + "parent_span_id": s.parent.span_id if s.parent else None, + } + for s in spans + ] + @app.get( "/apps/{app_name}/users/{user_id}/sessions/{session_id}", response_model_exclude_none=True, @@ -306,7 +364,6 @@ def get_fast_api_app( raise HTTPException( status_code=400, detail=f"Session already exists: {session_id}" ) - logger.info("New session created: %s", session_id) return session_service.create_session( app_name=app_name, user_id=user_id, state=state, session_id=session_id @@ -323,7 +380,6 @@ def get_fast_api_app( ) -> Session: # Connect to managed session if agent_engine_id is set. app_name = agent_engine_id if agent_engine_id else app_name - logger.info("New session created") return session_service.create_session( app_name=app_name, user_id=user_id, state=state diff --git a/src/google/adk/telemetry.py b/src/google/adk/telemetry.py index 0ee6cf8..dd32b3b 100644 --- a/src/google/adk/telemetry.py +++ b/src/google/adk/telemetry.py @@ -111,6 +111,9 @@ def trace_call_llm( span.set_attribute( 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id ) + span.set_attribute( + 'gcp.vertex.agent.session_id', invocation_context.session.id + ) span.set_attribute('gcp.vertex.agent.event_id', event_id) # Consider removing once GenAI SDK provides a way to record this info. span.set_attribute(