refactor(a2a): enhance A2A agent functionality with improved error handling, logging, and support for streaming requests

This commit is contained in:
Davidson Gomes 2025-05-05 21:46:24 -03:00
parent ec9dc07d71
commit 16e9747cce
2 changed files with 362 additions and 154 deletions

View File

@ -2,23 +2,28 @@ from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event from google.adk.events import Event
from google.genai.types import Content, Part from google.genai.types import Content, Part
from typing import AsyncGenerator from typing import AsyncGenerator, Dict, Any, Optional
import json import json
import asyncio import asyncio
import time import time
import logging
from src.schemas.a2a_types import ( from src.schemas.a2a_types import (
GetTaskRequest, GetTaskRequest,
SendTaskRequest, SendTaskRequest,
SendTaskStreamingRequest,
Message, Message,
TextPart, TextPart,
TaskState, TaskState,
JSONRPCRequest,
) )
import httpx import httpx
from uuid import uuid4 from uuid import uuid4
logger = logging.getLogger(__name__)
class A2ACustomAgent(BaseAgent): class A2ACustomAgent(BaseAgent):
""" """
@ -29,17 +34,22 @@ class A2ACustomAgent(BaseAgent):
# Field declarations for Pydantic # Field declarations for Pydantic
agent_card_url: str agent_card_url: str
api_key: Optional[str]
poll_interval: float poll_interval: float
max_wait_time: int max_wait_time: int
timeout: int timeout: int
streaming: bool
base_url: Optional[str] = None
def __init__( def __init__(
self, self,
name: str, name: str,
agent_card_url: str, agent_card_url: str,
api_key: Optional[str] = None,
poll_interval: float = 1.0, poll_interval: float = 1.0,
max_wait_time: int = 60, max_wait_time: int = 60,
timeout: int = 300, timeout: int = 300,
streaming: bool = True,
**kwargs, **kwargs,
): ):
""" """
@ -48,21 +58,83 @@ class A2ACustomAgent(BaseAgent):
Args: Args:
name: Agent name name: Agent name
agent_card_url: A2A agent card URL agent_card_url: A2A agent card URL
api_key: API key for authentication
poll_interval: Status check interval (seconds) poll_interval: Status check interval (seconds)
max_wait_time: Maximum wait time for a task (seconds) max_wait_time: Maximum wait time for a task (seconds)
timeout: Maximum execution time (seconds) timeout: Maximum execution time (seconds)
streaming: Whether to use streaming mode
""" """
# Initialize base class # Get base URL by removing agent.json if present
derived_base_url = agent_card_url
if "/.well-known/agent.json" in derived_base_url:
derived_base_url = derived_base_url.split("/.well-known/agent.json")[0]
# Initialize base class with all fields including base_url
super().__init__( super().__init__(
name=name, name=name,
agent_card_url=agent_card_url, agent_card_url=agent_card_url,
api_key=api_key,
poll_interval=poll_interval, poll_interval=poll_interval,
max_wait_time=max_wait_time, max_wait_time=max_wait_time,
timeout=timeout, timeout=timeout,
streaming=streaming,
base_url=derived_base_url,
**kwargs, **kwargs,
) )
print(f"A2A agent initialized for URL: {agent_card_url}") logger.info(f"A2A agent initialized for URL: {agent_card_url}")
# Default headers
self.headers = {"Content-Type": "application/json"}
if api_key:
self.headers["x-api-key"] = api_key
async def _send_jsonrpc_request(self, request: JSONRPCRequest) -> Dict[str, Any]:
"""
Send a JSON-RPC request to the A2A endpoint.
Args:
request: The JSON-RPC request object
Returns:
Dict containing the response
"""
try:
async with httpx.AsyncClient() as client:
logger.debug(
f"Sending request to {self.base_url}: {request.model_dump()}"
)
response = await client.post(
self.base_url,
json=request.model_dump(),
headers=self.headers,
timeout=30,
)
response.raise_for_status()
response_data = response.json()
logger.debug(f"Received response: {response_data}")
# Check for JSON-RPC errors
if "error" in response_data and response_data["error"]:
error_msg = response_data["error"].get("message", "Unknown error")
error_code = response_data["error"].get("code", -1)
logger.error(f"JSON-RPC error {error_code}: {error_msg}")
raise ValueError(f"A2A server error: {error_msg}")
return response_data
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error: {e.response.status_code} - {e}")
raise ValueError(f"HTTP error {e.response.status_code}: {str(e)}")
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
raise ValueError(f"Request error: {str(e)}")
except json.JSONDecodeError as e:
logger.error(f"JSON decode error: {e}")
raise ValueError(f"Invalid JSON response: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error: {e}")
raise ValueError(f"Error communicating with A2A server: {str(e)}")
async def _run_async_impl( async def _run_async_impl(
self, ctx: InvocationContext self, ctx: InvocationContext
@ -77,13 +149,6 @@ class A2ACustomAgent(BaseAgent):
yield Event(author=self.name) yield Event(author=self.name)
try: try:
# Prepare the base URL for the A2A
url = self.agent_card_url
# Ensure that there is no /.well-known/agent.json in the url
if "/.well-known/agent.json" in url:
url = url.split("/.well-known/agent.json")[0]
# 2. Extract the user's message from the context # 2. Extract the user's message from the context
user_message = None user_message = None
@ -92,7 +157,7 @@ class A2ACustomAgent(BaseAgent):
for event in reversed(ctx.session.events): for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts: if event.author == "user" and event.content and event.content.parts:
user_message = event.content.parts[0].text user_message = event.content.parts[0].text
print("Message found in session events") logger.info("Message found in session events")
break break
# Check in the session state if the message was not found in the events # Check in the session state if the message was not found in the events
@ -102,8 +167,21 @@ class A2ACustomAgent(BaseAgent):
elif "message" in ctx.session.state: elif "message" in ctx.session.state:
user_message = ctx.session.state["message"] user_message = ctx.session.state["message"]
if not user_message:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="Error: No message found to send to A2A agent")
],
),
)
yield Event(author=self.name) # Final event
return
# 3. Create and send the task to the A2A agent # 3. Create and send the task to the A2A agent
print(f"Sending task to A2A agent: {user_message[:100]}...") logger.info(f"Sending task to A2A agent: {user_message[:100]}...")
# Use the session ID as a stable identifier # Use the session ID as a stable identifier
session_id = ( session_id = (
@ -112,65 +190,166 @@ class A2ACustomAgent(BaseAgent):
else str(uuid4()) else str(uuid4())
) )
task_id = str(uuid4()) task_id = str(uuid4())
request_id = str(uuid4())
try: try:
# Format message according to A2A protocol
formatted_message: Message = Message( formatted_message = Message(
role="user", role="user",
parts=[TextPart(type="text", text=user_message)], parts=[TextPart(type="text", text=user_message)],
) )
request: SendTaskRequest = SendTaskRequest( # Prepare standard params for A2A
params={ task_params = {
"message": formatted_message,
"sessionId": session_id,
"id": task_id, "id": task_id,
"sessionId": session_id,
"message": formatted_message,
"acceptedOutputModes": ["text"],
} }
)
print(f"Request send task: {request.model_dump()}")
# REQUEST POST to url when jsonrpc is 2.0
task_result = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
)
print(f"Task response: {task_result.json()}")
print(f"Task sent successfully, ID: {task_id}")
# Emit processing message
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", parts=[Part(text="Processing request...")] role="agent", parts=[Part(text="Processing your request...")]
), ),
) )
except Exception as e:
error_msg = f"Error sending request: {str(e)}"
print(error_msg)
print(f"Error type: {type(e).__name__}")
print(f"Error details: {str(e)}")
if self.streaming:
# Use streaming mode
request = SendTaskStreamingRequest(
id=request_id, params=task_params
)
# Handle streaming response
accumulated_response = ""
try:
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
self.base_url,
json=request.model_dump(),
headers=self.headers,
) as response:
response.raise_for_status()
# Process SSE events
async for line in response.aiter_lines():
if not line or line.strip() == "":
continue
if line.startswith("data:"):
data = line[5:].strip()
try:
event_data = json.loads(data)
logger.debug(f"SSE event: {event_data}")
# Process artifacts
if (
"result" in event_data
and "artifact" in event_data["result"]
):
artifact = event_data["result"][
"artifact"
]
if artifact and "parts" in artifact:
for part in artifact["parts"]:
if (
"text" in part
and part["text"]
):
accumulated_response += (
part["text"]
)
# Emit incremental update
yield Event( yield Event(
author=self.name, author=self.name,
content=Content(role="agent", parts=[Part(text=str(e))]), content=Content(
role="agent",
parts=[
Part(
text=accumulated_response
)
],
),
)
# Check if task is complete
if (
"result" in event_data
and "status" in event_data["result"]
and "final" in event_data["result"]
and event_data["result"]["final"]
):
logger.info(
"Task completed in streaming mode"
)
break
except json.JSONDecodeError as e:
logger.error(
f"Error parsing SSE event: {e}"
)
except Exception as e:
logger.error(f"Error in streaming mode: {e}")
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error in streaming mode: {str(e)}")],
),
)
# If we have a response, we're done
if accumulated_response:
yield Event(author=self.name) # Final event
return
else:
# Use non-streaming mode
request = SendTaskRequest(id=request_id, params=task_params)
# Make the request
response_data = await self._send_jsonrpc_request(request)
# Process the response
if "result" in response_data and response_data["result"]:
task_result = response_data["result"]
# Extract message from response
if (
"status" in task_result
and "message" in task_result["status"]
and "parts" in task_result["status"]["message"]
):
parts = task_result["status"]["message"]["parts"]
for part in parts:
if "text" in part and part["text"]:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=part["text"])],
),
) )
yield Event(author=self.name) # Final event yield Event(author=self.name) # Final event
return return
# If we reach here, we need to poll for results
start_time = time.time() start_time = time.time()
while time.time() - start_time < self.timeout: while time.time() - start_time < self.timeout:
try: try:
# Check current status # Check current status
request: GetTaskRequest = GetTaskRequest(params={"id": task_id}) status_request = GetTaskRequest(
id=request_id, params={"id": task_id, "historyLength": 10}
task_status_response = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
) )
print(f"Response get task: {task_status_response.json()}") task_status = await self._send_jsonrpc_request(status_request)
task_status = task_status_response.json()
if "result" not in task_status or not task_status["result"]: if "result" not in task_status or not task_status["result"]:
await asyncio.sleep(self.poll_interval) await asyncio.sleep(self.poll_interval)
@ -184,7 +363,7 @@ class A2ACustomAgent(BaseAgent):
if "state" in task_status["result"]["status"]: if "state" in task_status["result"]["status"]:
current_state = task_status["result"]["status"]["state"] current_state = task_status["result"]["status"]["state"]
print(f"Task status {task_id}: {current_state}") logger.info(f"Task status {task_id}: {current_state}")
# Check if the task was completed # Check if the task was completed
if current_state in [ if current_state in [
@ -203,11 +382,13 @@ class A2ACustomAgent(BaseAgent):
# Convert A2A parts to ADK # Convert A2A parts to ADK
response_parts = [] response_parts = []
for part in task_status["result"]["status"]["message"][ for part in task_status["result"]["status"][
"parts" "message"
]: ]["parts"]:
if "text" in part and part["text"]: if "text" in part and part["text"]:
response_parts.append(Part(text=part["text"])) response_parts.append(
Part(text=part["text"])
)
elif "data" in part: elif "data" in part:
try: try:
json_text = json.dumps( json_text = json.dumps(
@ -216,7 +397,9 @@ class A2ACustomAgent(BaseAgent):
indent=2, indent=2,
) )
response_parts.append( response_parts.append(
Part(text=f"```json\n{json_text}\n```") Part(
text=f"```json\n{json_text}\n```"
)
) )
except Exception: except Exception:
response_parts.append( response_parts.append(
@ -236,7 +419,9 @@ class A2ACustomAgent(BaseAgent):
content=Content( content=Content(
role="agent", role="agent",
parts=[ parts=[
Part(text="Empty response from agent.") Part(
text="Empty response from agent."
)
], ],
), ),
) )
@ -258,7 +443,9 @@ class A2ACustomAgent(BaseAgent):
content=Content( content=Content(
role="agent", role="agent",
parts=[ parts=[
Part(text="The task failed during processing.") Part(
text="The task failed during processing."
)
], ],
), ),
) )
@ -283,7 +470,7 @@ class A2ACustomAgent(BaseAgent):
break # Exit the loop of checking break # Exit the loop of checking
except Exception as e: except Exception as e:
print(f"Error checking task: {str(e)}") logger.error(f"Error checking task: {str(e)}")
# If the timeout was exceeded, inform the user # If the timeout was exceeded, inform the user
if time.time() - start_time > self.max_wait_time: if time.time() - start_time > self.max_wait_time:
@ -313,9 +500,21 @@ class A2ACustomAgent(BaseAgent):
), ),
) )
except Exception as e:
logger.error(f"Error sending task: {str(e)}")
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text=f"Error interacting with A2A agent: {str(e)}")
],
),
)
except Exception as e: except Exception as e:
# Handle any uncaught error # Handle any uncaught error
print(f"Error executing A2A agent: {str(e)}") logger.error(f"Error executing A2A agent: {str(e)}")
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(

View File

@ -55,6 +55,8 @@ from src.schemas.a2a_types import (
AgentCard, AgentCard,
AgentCapabilities, AgentCapabilities,
AgentSkill, AgentSkill,
AgentAuthentication,
AgentProvider,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -558,7 +560,6 @@ class A2AService:
for tool_name in mcp_tools: for tool_name in mcp_tools:
logger.info(f"Processing tool: {tool_name}") logger.info(f"Processing tool: {tool_name}")
# Buscar informações da ferramenta pelo ID
tool_info = None tool_info = None
if hasattr(mcp_server, "tools") and isinstance( if hasattr(mcp_server, "tools") and isinstance(
mcp_server.tools, list mcp_server.tools, list
@ -633,10 +634,18 @@ class A2AService:
name=agent.name, name=agent.name,
description=agent.description or "", description=agent.description or "",
url=f"{settings.API_URL}/api/v1/a2a/{agent_id}", url=f"{settings.API_URL}/api/v1/a2a/{agent_id}",
version="1.0.0", provider=AgentProvider(
organization=settings.ORGANIZATION_NAME,
url=settings.ORGANIZATION_URL,
),
version=f"{settings.API_VERSION}",
capabilities=capabilities,
authentication=AgentAuthentication(
schemes=["apiKey"],
credentials="x-api-key",
),
defaultInputModes=["text"], defaultInputModes=["text"],
defaultOutputModes=["text"], defaultOutputModes=["text"],
capabilities=capabilities,
skills=skills, skills=skills,
) )