refactor(a2a): enhance A2A agent functionality with improved error handling, logging, and support for streaming requests

This commit is contained in:
Davidson Gomes 2025-05-05 21:46:24 -03:00
parent ec9dc07d71
commit 16e9747cce
2 changed files with 362 additions and 154 deletions

View File

@ -2,23 +2,28 @@ from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.genai.types import Content, Part
from typing import AsyncGenerator
from typing import AsyncGenerator, Dict, Any, Optional
import json
import asyncio
import time
import logging
from src.schemas.a2a_types import (
GetTaskRequest,
SendTaskRequest,
SendTaskStreamingRequest,
Message,
TextPart,
TaskState,
JSONRPCRequest,
)
import httpx
from uuid import uuid4
logger = logging.getLogger(__name__)
class A2ACustomAgent(BaseAgent):
"""
@ -29,17 +34,22 @@ class A2ACustomAgent(BaseAgent):
# Field declarations for Pydantic
agent_card_url: str
api_key: Optional[str]
poll_interval: float
max_wait_time: int
timeout: int
streaming: bool
base_url: Optional[str] = None
def __init__(
self,
name: str,
agent_card_url: str,
api_key: Optional[str] = None,
poll_interval: float = 1.0,
max_wait_time: int = 60,
timeout: int = 300,
streaming: bool = True,
**kwargs,
):
"""
@ -48,21 +58,83 @@ class A2ACustomAgent(BaseAgent):
Args:
name: Agent name
agent_card_url: A2A agent card URL
api_key: API key for authentication
poll_interval: Status check interval (seconds)
max_wait_time: Maximum wait time for a task (seconds)
timeout: Maximum execution time (seconds)
streaming: Whether to use streaming mode
"""
# Initialize base class
# Get base URL by removing agent.json if present
derived_base_url = agent_card_url
if "/.well-known/agent.json" in derived_base_url:
derived_base_url = derived_base_url.split("/.well-known/agent.json")[0]
# Initialize base class with all fields including base_url
super().__init__(
name=name,
agent_card_url=agent_card_url,
api_key=api_key,
poll_interval=poll_interval,
max_wait_time=max_wait_time,
timeout=timeout,
streaming=streaming,
base_url=derived_base_url,
**kwargs,
)
print(f"A2A agent initialized for URL: {agent_card_url}")
logger.info(f"A2A agent initialized for URL: {agent_card_url}")
# Default headers
self.headers = {"Content-Type": "application/json"}
if api_key:
self.headers["x-api-key"] = api_key
async def _send_jsonrpc_request(self, request: JSONRPCRequest) -> Dict[str, Any]:
    """
    Send a JSON-RPC request to the A2A endpoint.

    The request is POSTed to ``self.base_url`` with ``self.headers`` and a
    30 second timeout. Transport-level failures and protocol-level failures
    are reported separately so that a JSON-RPC error returned by the server
    is not re-labelled as a communication problem.

    Args:
        request: The JSON-RPC request object

    Returns:
        Dict containing the decoded JSON-RPC response

    Raises:
        ValueError: If the HTTP call fails (bad status, transport error,
            undecodable body, or any other unexpected failure), or if the
            response carries a non-empty JSON-RPC ``error`` member.
    """
    try:
        # Serialize once; the same payload is logged and sent.
        payload = request.model_dump()
        async with httpx.AsyncClient() as client:
            logger.debug(f"Sending request to {self.base_url}: {payload}")
            response = await client.post(
                self.base_url,
                json=payload,
                headers=self.headers,
                timeout=30,
            )
            response.raise_for_status()
            response_data = response.json()
            logger.debug(f"Received response: {response_data}")
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error: {e.response.status_code} - {e}")
        raise ValueError(f"HTTP error {e.response.status_code}: {str(e)}") from e
    except httpx.RequestError as e:
        logger.error(f"Request error: {e}")
        raise ValueError(f"Request error: {str(e)}") from e
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error: {e}")
        raise ValueError(f"Invalid JSON response: {str(e)}") from e
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise ValueError(f"Error communicating with A2A server: {str(e)}") from e

    # Check for JSON-RPC errors. This is done outside the try block on
    # purpose: raising inside it would let the generic `except Exception`
    # handler above catch our own ValueError and wrap it a second time.
    error = response_data.get("error") if isinstance(response_data, dict) else None
    if error:
        if isinstance(error, dict):
            error_msg = error.get("message", "Unknown error")
            error_code = error.get("code", -1)
        else:
            # Non-conforming servers may send a bare string or number here.
            error_msg = str(error)
            error_code = -1
        logger.error(f"JSON-RPC error {error_code}: {error_msg}")
        raise ValueError(f"A2A server error: {error_msg}")

    return response_data
async def _run_async_impl(
self, ctx: InvocationContext
@ -77,13 +149,6 @@ class A2ACustomAgent(BaseAgent):
yield Event(author=self.name)
try:
# Prepare the base URL for the A2A
url = self.agent_card_url
# Ensure that there is no /.well-known/agent.json in the url
if "/.well-known/agent.json" in url:
url = url.split("/.well-known/agent.json")[0]
# 2. Extract the user's message from the context
user_message = None
@ -92,7 +157,7 @@ class A2ACustomAgent(BaseAgent):
for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts:
user_message = event.content.parts[0].text
print("Message found in session events")
logger.info("Message found in session events")
break
# Check in the session state if the message was not found in the events
@ -102,8 +167,21 @@ class A2ACustomAgent(BaseAgent):
elif "message" in ctx.session.state:
user_message = ctx.session.state["message"]
if not user_message:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="Error: No message found to send to A2A agent")
],
),
)
yield Event(author=self.name) # Final event
return
# 3. Create and send the task to the A2A agent
print(f"Sending task to A2A agent: {user_message[:100]}...")
logger.info(f"Sending task to A2A agent: {user_message[:100]}...")
# Use the session ID as a stable identifier
session_id = (
@ -112,210 +190,331 @@ class A2ACustomAgent(BaseAgent):
else str(uuid4())
)
task_id = str(uuid4())
request_id = str(uuid4())
try:
formatted_message: Message = Message(
# Format message according to A2A protocol
formatted_message = Message(
role="user",
parts=[TextPart(type="text", text=user_message)],
)
request: SendTaskRequest = SendTaskRequest(
params={
"message": formatted_message,
"sessionId": session_id,
"id": task_id,
}
)
print(f"Request send task: {request.model_dump()}")
# REQUEST POST to url when jsonrpc is 2.0
task_result = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
)
print(f"Task response: {task_result.json()}")
print(f"Task sent successfully, ID: {task_id}")
# Prepare standard params for A2A
task_params = {
"id": task_id,
"sessionId": session_id,
"message": formatted_message,
"acceptedOutputModes": ["text"],
}
# Emit processing message
yield Event(
author=self.name,
content=Content(
role="agent", parts=[Part(text="Processing request...")]
role="agent", parts=[Part(text="Processing your request...")]
),
)
except Exception as e:
error_msg = f"Error sending request: {str(e)}"
print(error_msg)
print(f"Error type: {type(e).__name__}")
print(f"Error details: {str(e)}")
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=str(e))]),
)
yield Event(author=self.name) # Final event
return
start_time = time.time()
while time.time() - start_time < self.timeout:
try:
# Check current status
request: GetTaskRequest = GetTaskRequest(params={"id": task_id})
task_status_response = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
if self.streaming:
# Use streaming mode
request = SendTaskStreamingRequest(
id=request_id, params=task_params
)
print(f"Response get task: {task_status_response.json()}")
# Handle streaming response
accumulated_response = ""
task_status = task_status_response.json()
try:
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
self.base_url,
json=request.model_dump(),
headers=self.headers,
) as response:
response.raise_for_status()
if "result" not in task_status or not task_status["result"]:
await asyncio.sleep(self.poll_interval)
continue
# Process SSE events
async for line in response.aiter_lines():
if not line or line.strip() == "":
continue
current_state = None
if (
"status" in task_status["result"]
and task_status["result"]["status"]
):
if "state" in task_status["result"]["status"]:
current_state = task_status["result"]["status"]["state"]
print(f"Task status {task_id}: {current_state}")
# Check if the task was completed
if current_state in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]:
if current_state == TaskState.COMPLETED:
# Extract the response
if (
"status" in task_status["result"]
and "message" in task_status["result"]["status"]
and "parts"
in task_status["result"]["status"]["message"]
):
# Convert A2A parts to ADK
response_parts = []
for part in task_status["result"]["status"]["message"][
"parts"
]:
if "text" in part and part["text"]:
response_parts.append(Part(text=part["text"]))
elif "data" in part:
if line.startswith("data:"):
data = line[5:].strip()
try:
json_text = json.dumps(
part["data"],
ensure_ascii=False,
indent=2,
)
response_parts.append(
Part(text=f"```json\n{json_text}\n```")
)
except Exception:
response_parts.append(
Part(text="[Unserializable data]")
event_data = json.loads(data)
logger.debug(f"SSE event: {event_data}")
# Process artifacts
if (
"result" in event_data
and "artifact" in event_data["result"]
):
artifact = event_data["result"][
"artifact"
]
if artifact and "parts" in artifact:
for part in artifact["parts"]:
if (
"text" in part
and part["text"]
):
accumulated_response += (
part["text"]
)
# Emit incremental update
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text=accumulated_response
)
],
),
)
# Check if task is complete
if (
"result" in event_data
and "status" in event_data["result"]
and "final" in event_data["result"]
and event_data["result"]["final"]
):
logger.info(
"Task completed in streaming mode"
)
break
except json.JSONDecodeError as e:
logger.error(
f"Error parsing SSE event: {e}"
)
if response_parts:
except Exception as e:
logger.error(f"Error in streaming mode: {e}")
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error in streaming mode: {str(e)}")],
),
)
# If we have a response, we're done
if accumulated_response:
yield Event(author=self.name) # Final event
return
else:
# Use non-streaming mode
request = SendTaskRequest(id=request_id, params=task_params)
# Make the request
response_data = await self._send_jsonrpc_request(request)
# Process the response
if "result" in response_data and response_data["result"]:
task_result = response_data["result"]
# Extract message from response
if (
"status" in task_result
and "message" in task_result["status"]
and "parts" in task_result["status"]["message"]
):
parts = task_result["status"]["message"]["parts"]
for part in parts:
if "text" in part and part["text"]:
yield Event(
author=self.name,
content=Content(
role="agent", parts=response_parts
role="agent",
parts=[Part(text=part["text"])],
),
)
yield Event(author=self.name) # Final event
return
# If we reach here, we need to poll for results
start_time = time.time()
while time.time() - start_time < self.timeout:
try:
# Check current status
status_request = GetTaskRequest(
id=request_id, params={"id": task_id, "historyLength": 10}
)
task_status = await self._send_jsonrpc_request(status_request)
if "result" not in task_status or not task_status["result"]:
await asyncio.sleep(self.poll_interval)
continue
current_state = None
if (
"status" in task_status["result"]
and task_status["result"]["status"]
):
if "state" in task_status["result"]["status"]:
current_state = task_status["result"]["status"]["state"]
logger.info(f"Task status {task_id}: {current_state}")
# Check if the task was completed
if current_state in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]:
if current_state == TaskState.COMPLETED:
# Extract the response
if (
"status" in task_status["result"]
and "message" in task_status["result"]["status"]
and "parts"
in task_status["result"]["status"]["message"]
):
# Convert A2A parts to ADK
response_parts = []
for part in task_status["result"]["status"][
"message"
]["parts"]:
if "text" in part and part["text"]:
response_parts.append(
Part(text=part["text"])
)
elif "data" in part:
try:
json_text = json.dumps(
part["data"],
ensure_ascii=False,
indent=2,
)
response_parts.append(
Part(
text=f"```json\n{json_text}\n```"
)
)
except Exception:
response_parts.append(
Part(text="[Unserializable data]")
)
if response_parts:
yield Event(
author=self.name,
content=Content(
role="agent", parts=response_parts
),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Empty response from agent."
)
],
),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="Empty response from agent.")
Part(
text="Task completed, but no response message."
)
],
),
)
else:
elif current_state == TaskState.FAILED:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Task completed, but no response message."
text="The task failed during processing."
)
],
),
)
elif current_state == TaskState.FAILED:
else: # CANCELED
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text="The task was canceled.")],
),
)
# Store in the session state for future reference
if ctx.session:
try:
ctx.session.state["a2a_task_result"] = task_status[
"result"
]
except Exception:
pass
break # Exit the loop of checking
except Exception as e:
logger.error(f"Error checking task: {str(e)}")
# If the timeout was exceeded, inform the user
if time.time() - start_time > self.max_wait_time:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="The task failed during processing.")
],
),
)
else: # CANCELED
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text="The task was canceled.")],
parts=[Part(text=f"Error checking task: {str(e)}")],
),
)
break
# Store in the session state for future reference
if ctx.session:
try:
ctx.session.state["a2a_task_result"] = task_status[
"result"
]
except Exception:
pass
# Wait before the next check
await asyncio.sleep(self.poll_interval)
break # Exit the loop of checking
# If the timeout was exceeded
if time.time() - start_time >= self.timeout:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="The operation exceeded the timeout. Please try again later."
)
],
),
)
except Exception as e:
print(f"Error checking task: {str(e)}")
# If the timeout was exceeded, inform the user
if time.time() - start_time > self.max_wait_time:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error checking task: {str(e)}")],
),
)
break
# Wait before the next check
await asyncio.sleep(self.poll_interval)
# If the timeout was exceeded
if time.time() - start_time >= self.timeout:
except Exception as e:
logger.error(f"Error sending task: {str(e)}")
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="The operation exceeded the timeout. Please try again later."
)
Part(text=f"Error interacting with A2A agent: {str(e)}")
],
),
)
except Exception as e:
# Handle any uncaught error
print(f"Error executing A2A agent: {str(e)}")
logger.error(f"Error executing A2A agent: {str(e)}")
yield Event(
author=self.name,
content=Content(

View File

@ -55,6 +55,8 @@ from src.schemas.a2a_types import (
AgentCard,
AgentCapabilities,
AgentSkill,
AgentAuthentication,
AgentProvider,
)
logger = logging.getLogger(__name__)
@ -558,7 +560,6 @@ class A2AService:
for tool_name in mcp_tools:
logger.info(f"Processing tool: {tool_name}")
# Buscar informações da ferramenta pelo ID
tool_info = None
if hasattr(mcp_server, "tools") and isinstance(
mcp_server.tools, list
@ -633,10 +634,18 @@ class A2AService:
name=agent.name,
description=agent.description or "",
url=f"{settings.API_URL}/api/v1/a2a/{agent_id}",
version="1.0.0",
provider=AgentProvider(
organization=settings.ORGANIZATION_NAME,
url=settings.ORGANIZATION_URL,
),
version=f"{settings.API_VERSION}",
capabilities=capabilities,
authentication=AgentAuthentication(
schemes=["apiKey"],
credentials="x-api-key",
),
defaultInputModes=["text"],
defaultOutputModes=["text"],
capabilities=capabilities,
skills=skills,
)