Add debug trace endpoint in api server

Details:
- Add a in-memory SpanExporter to capture all trace information.
- Add /debug/trace/session/{session_id} endpoint to retrieve traces from the in-memory exporter.
- Add Session ID in Telemetry spans.

PiperOrigin-RevId: 757984565
This commit is contained in:
Yifan Wang 2025-05-12 17:56:20 -07:00 committed by Copybara-Service
parent d35b99e6dd
commit 80813a75cf
2 changed files with 61 additions and 2 deletions

View File

@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict
@ -112,6 +113,42 @@ class ApiServerSpanExporter(export.SpanExporter):
return True
class InMemoryExporter(export.SpanExporter):
def __init__(self, trace_dict):
super().__init__()
self._spans = []
self.trace_dict = trace_dict
def export(
self, spans: typing.Sequence[ReadableSpan]
) -> export.SpanExportResult:
for span in spans:
trace_id = span.context.trace_id
if span.name == "call_llm":
attributes = dict(span.attributes)
session_id = attributes.get("gcp.vertex.agent.session_id", None)
if session_id:
if session_id not in self.trace_dict:
self.trace_dict[session_id] = [trace_id]
else:
self.trace_dict[session_id] += [trace_id]
self._spans.extend(spans)
return export.SpanExportResult.SUCCESS
def get_finished_spans(self, session_id: str):
trace_ids = self.trace_dict.get(session_id, None)
if trace_ids is None or not trace_ids:
return []
return [x for x in self._spans if x.context.trace_id in trace_ids]
def force_flush(self, timeout_millis: int = 30000) -> bool:
return True
def clear(self):
self._spans.clear()
class AgentRunRequest(BaseModel):
app_name: str
user_id: str
@ -152,12 +189,15 @@ def get_fast_api_app(
) -> FastAPI:
# InMemory tracing dict.
trace_dict: dict[str, Any] = {}
session_trace_dict: dict[str, Any] = {}
# Set up tracing in the FastAPI server.
provider = TracerProvider()
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
memory_exporter = InMemoryExporter(session_trace_dict)
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
@ -254,6 +294,24 @@ def get_fast_api_app(
raise HTTPException(status_code=404, detail="Trace not found")
return event_dict
@app.get("/debug/trace/session/{session_id}")
def get_session_trace(session_id: str) -> Any:
spans = memory_exporter.get_finished_spans(session_id)
if not spans:
return []
return [
{
"name": s.name,
"span_id": s.context.span_id,
"trace_id": s.context.trace_id,
"start_time": s.start_time,
"end_time": s.end_time,
"attributes": dict(s.attributes),
"parent_span_id": s.parent.span_id if s.parent else None,
}
for s in spans
]
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
@ -306,7 +364,6 @@ def get_fast_api_app(
raise HTTPException(
status_code=400, detail=f"Session already exists: {session_id}"
)
logger.info("New session created: %s", session_id)
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id
@ -323,7 +380,6 @@ def get_fast_api_app(
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state

View File

@ -111,6 +111,9 @@ def trace_call_llm(
span.set_attribute(
'gcp.vertex.agent.invocation_id', invocation_context.invocation_id
)
span.set_attribute(
'gcp.vertex.agent.session_id', invocation_context.session.id
)
span.set_attribute('gcp.vertex.agent.event_id', event_id)
# Consider removing once GenAI SDK provides a way to record this info.
span.set_attribute(