refactor(a2a): enhance A2A agent functionality with improved error handling, logging, and support for streaming requests

This commit is contained in:
Davidson Gomes 2025-05-05 21:46:24 -03:00
parent ec9dc07d71
commit 16e9747cce
2 changed files with 362 additions and 154 deletions

View File

@ -2,23 +2,28 @@ from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event from google.adk.events import Event
from google.genai.types import Content, Part from google.genai.types import Content, Part
from typing import AsyncGenerator from typing import AsyncGenerator, Dict, Any, Optional
import json import json
import asyncio import asyncio
import time import time
import logging
from src.schemas.a2a_types import ( from src.schemas.a2a_types import (
GetTaskRequest, GetTaskRequest,
SendTaskRequest, SendTaskRequest,
SendTaskStreamingRequest,
Message, Message,
TextPart, TextPart,
TaskState, TaskState,
JSONRPCRequest,
) )
import httpx import httpx
from uuid import uuid4 from uuid import uuid4
logger = logging.getLogger(__name__)
class A2ACustomAgent(BaseAgent): class A2ACustomAgent(BaseAgent):
""" """
@ -29,17 +34,22 @@ class A2ACustomAgent(BaseAgent):
# Field declarations for Pydantic # Field declarations for Pydantic
agent_card_url: str agent_card_url: str
api_key: Optional[str]
poll_interval: float poll_interval: float
max_wait_time: int max_wait_time: int
timeout: int timeout: int
streaming: bool
base_url: Optional[str] = None
def __init__( def __init__(
self, self,
name: str, name: str,
agent_card_url: str, agent_card_url: str,
api_key: Optional[str] = None,
poll_interval: float = 1.0, poll_interval: float = 1.0,
max_wait_time: int = 60, max_wait_time: int = 60,
timeout: int = 300, timeout: int = 300,
streaming: bool = True,
**kwargs, **kwargs,
): ):
""" """
@ -48,21 +58,83 @@ class A2ACustomAgent(BaseAgent):
Args: Args:
name: Agent name name: Agent name
agent_card_url: A2A agent card URL agent_card_url: A2A agent card URL
api_key: API key for authentication
poll_interval: Status check interval (seconds) poll_interval: Status check interval (seconds)
max_wait_time: Maximum wait time for a task (seconds) max_wait_time: Maximum wait time for a task (seconds)
timeout: Maximum execution time (seconds) timeout: Maximum execution time (seconds)
streaming: Whether to use streaming mode
""" """
# Initialize base class # Get base URL by removing agent.json if present
derived_base_url = agent_card_url
if "/.well-known/agent.json" in derived_base_url:
derived_base_url = derived_base_url.split("/.well-known/agent.json")[0]
# Initialize base class with all fields including base_url
super().__init__( super().__init__(
name=name, name=name,
agent_card_url=agent_card_url, agent_card_url=agent_card_url,
api_key=api_key,
poll_interval=poll_interval, poll_interval=poll_interval,
max_wait_time=max_wait_time, max_wait_time=max_wait_time,
timeout=timeout, timeout=timeout,
streaming=streaming,
base_url=derived_base_url,
**kwargs, **kwargs,
) )
print(f"A2A agent initialized for URL: {agent_card_url}") logger.info(f"A2A agent initialized for URL: {agent_card_url}")
# Default headers
self.headers = {"Content-Type": "application/json"}
if api_key:
self.headers["x-api-key"] = api_key
async def _send_jsonrpc_request(self, request: JSONRPCRequest) -> Dict[str, Any]:
    """
    Send a JSON-RPC request to the A2A endpoint and return the decoded reply.

    The request is serialized with ``model_dump()`` and POSTed to
    ``self.base_url`` using ``self.headers``.

    Args:
        request: The JSON-RPC request object

    Returns:
        Dict containing the response

    Raises:
        ValueError: On an HTTP error status, a transport failure, a body
            that is not valid JSON, a non-empty JSON-RPC ``error`` member
            in the reply, or any other unexpected failure. When the error
            wraps another exception, that exception is attached as
            ``__cause__``.
    """
    try:
        # Serialize once; the same payload is both logged and sent.
        payload = request.model_dump()

        async with httpx.AsyncClient() as client:
            logger.debug(f"Sending request to {self.base_url}: {payload}")

            response = await client.post(
                self.base_url,
                json=payload,
                headers=self.headers,
                timeout=30,
            )
            response.raise_for_status()

            response_data = response.json()
            logger.debug(f"Received response: {response_data}")

            # Check for JSON-RPC errors
            if "error" in response_data and response_data["error"]:
                error_msg = response_data["error"].get("message", "Unknown error")
                error_code = response_data["error"].get("code", -1)
                logger.error(f"JSON-RPC error {error_code}: {error_msg}")
                raise ValueError(f"A2A server error: {error_msg}")

            return response_data

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error: {e.response.status_code} - {e}")
        raise ValueError(f"HTTP error {e.response.status_code}: {str(e)}") from e
    except httpx.RequestError as e:
        logger.error(f"Request error: {e}")
        raise ValueError(f"Request error: {str(e)}") from e
    except json.JSONDecodeError as e:
        # Must precede the ValueError clause below: JSONDecodeError is a
        # ValueError subclass and needs its own message.
        logger.error(f"JSON decode error: {e}")
        raise ValueError(f"Invalid JSON response: {str(e)}") from e
    except ValueError:
        # The JSON-RPC error raised above has already been logged; let it
        # propagate unchanged instead of reaching the generic handler,
        # which would log it as "Unexpected error" and wrap its message a
        # second time. Any other ValueError raised in the body is likewise
        # passed through as-is.
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise ValueError(f"Error communicating with A2A server: {str(e)}") from e
async def _run_async_impl( async def _run_async_impl(
self, ctx: InvocationContext self, ctx: InvocationContext
@ -77,13 +149,6 @@ class A2ACustomAgent(BaseAgent):
yield Event(author=self.name) yield Event(author=self.name)
try: try:
# Prepare the base URL for the A2A
url = self.agent_card_url
# Ensure that there is no /.well-known/agent.json in the url
if "/.well-known/agent.json" in url:
url = url.split("/.well-known/agent.json")[0]
# 2. Extract the user's message from the context # 2. Extract the user's message from the context
user_message = None user_message = None
@ -92,7 +157,7 @@ class A2ACustomAgent(BaseAgent):
for event in reversed(ctx.session.events): for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts: if event.author == "user" and event.content and event.content.parts:
user_message = event.content.parts[0].text user_message = event.content.parts[0].text
print("Message found in session events") logger.info("Message found in session events")
break break
# Check in the session state if the message was not found in the events # Check in the session state if the message was not found in the events
@ -102,8 +167,21 @@ class A2ACustomAgent(BaseAgent):
elif "message" in ctx.session.state: elif "message" in ctx.session.state:
user_message = ctx.session.state["message"] user_message = ctx.session.state["message"]
if not user_message:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="Error: No message found to send to A2A agent")
],
),
)
yield Event(author=self.name) # Final event
return
# 3. Create and send the task to the A2A agent # 3. Create and send the task to the A2A agent
print(f"Sending task to A2A agent: {user_message[:100]}...") logger.info(f"Sending task to A2A agent: {user_message[:100]}...")
# Use the session ID as a stable identifier # Use the session ID as a stable identifier
session_id = ( session_id = (
@ -112,210 +190,331 @@ class A2ACustomAgent(BaseAgent):
else str(uuid4()) else str(uuid4())
) )
task_id = str(uuid4()) task_id = str(uuid4())
request_id = str(uuid4())
try: try:
# Format message according to A2A protocol
formatted_message: Message = Message( formatted_message = Message(
role="user", role="user",
parts=[TextPart(type="text", text=user_message)], parts=[TextPart(type="text", text=user_message)],
) )
request: SendTaskRequest = SendTaskRequest( # Prepare standard params for A2A
params={ task_params = {
"message": formatted_message, "id": task_id,
"sessionId": session_id, "sessionId": session_id,
"id": task_id, "message": formatted_message,
} "acceptedOutputModes": ["text"],
) }
print(f"Request send task: {request.model_dump()}")
# REQUEST POST to url when jsonrpc is 2.0
task_result = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
)
print(f"Task response: {task_result.json()}")
print(f"Task sent successfully, ID: {task_id}")
# Emit processing message
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", parts=[Part(text="Processing request...")] role="agent", parts=[Part(text="Processing your request...")]
), ),
) )
except Exception as e:
error_msg = f"Error sending request: {str(e)}"
print(error_msg)
print(f"Error type: {type(e).__name__}")
print(f"Error details: {str(e)}")
yield Event( if self.streaming:
author=self.name, # Use streaming mode
content=Content(role="agent", parts=[Part(text=str(e))]), request = SendTaskStreamingRequest(
) id=request_id, params=task_params
yield Event(author=self.name) # Final event
return
start_time = time.time()
while time.time() - start_time < self.timeout:
try:
# Check current status
request: GetTaskRequest = GetTaskRequest(params={"id": task_id})
task_status_response = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
) )
print(f"Response get task: {task_status_response.json()}") # Handle streaming response
accumulated_response = ""
task_status = task_status_response.json() try:
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
self.base_url,
json=request.model_dump(),
headers=self.headers,
) as response:
response.raise_for_status()
if "result" not in task_status or not task_status["result"]: # Process SSE events
await asyncio.sleep(self.poll_interval) async for line in response.aiter_lines():
continue if not line or line.strip() == "":
continue
current_state = None if line.startswith("data:"):
if ( data = line[5:].strip()
"status" in task_status["result"]
and task_status["result"]["status"]
):
if "state" in task_status["result"]["status"]:
current_state = task_status["result"]["status"]["state"]
print(f"Task status {task_id}: {current_state}")
# Check if the task was completed
if current_state in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]:
if current_state == TaskState.COMPLETED:
# Extract the response
if (
"status" in task_status["result"]
and "message" in task_status["result"]["status"]
and "parts"
in task_status["result"]["status"]["message"]
):
# Convert A2A parts to ADK
response_parts = []
for part in task_status["result"]["status"]["message"][
"parts"
]:
if "text" in part and part["text"]:
response_parts.append(Part(text=part["text"]))
elif "data" in part:
try: try:
json_text = json.dumps( event_data = json.loads(data)
part["data"], logger.debug(f"SSE event: {event_data}")
ensure_ascii=False,
indent=2, # Process artifacts
) if (
response_parts.append( "result" in event_data
Part(text=f"```json\n{json_text}\n```") and "artifact" in event_data["result"]
) ):
except Exception: artifact = event_data["result"][
response_parts.append( "artifact"
Part(text="[Unserializable data]") ]
if artifact and "parts" in artifact:
for part in artifact["parts"]:
if (
"text" in part
and part["text"]
):
accumulated_response += (
part["text"]
)
# Emit incremental update
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text=accumulated_response
)
],
),
)
# Check if task is complete
if (
"result" in event_data
and "status" in event_data["result"]
and "final" in event_data["result"]
and event_data["result"]["final"]
):
logger.info(
"Task completed in streaming mode"
)
break
except json.JSONDecodeError as e:
logger.error(
f"Error parsing SSE event: {e}"
) )
if response_parts: except Exception as e:
logger.error(f"Error in streaming mode: {e}")
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error in streaming mode: {str(e)}")],
),
)
# If we have a response, we're done
if accumulated_response:
yield Event(author=self.name) # Final event
return
else:
# Use non-streaming mode
request = SendTaskRequest(id=request_id, params=task_params)
# Make the request
response_data = await self._send_jsonrpc_request(request)
# Process the response
if "result" in response_data and response_data["result"]:
task_result = response_data["result"]
# Extract message from response
if (
"status" in task_result
and "message" in task_result["status"]
and "parts" in task_result["status"]["message"]
):
parts = task_result["status"]["message"]["parts"]
for part in parts:
if "text" in part and part["text"]:
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", parts=response_parts role="agent",
parts=[Part(text=part["text"])],
), ),
) )
yield Event(author=self.name) # Final event
return
# If we reach here, we need to poll for results
start_time = time.time()
while time.time() - start_time < self.timeout:
try:
# Check current status
status_request = GetTaskRequest(
id=request_id, params={"id": task_id, "historyLength": 10}
)
task_status = await self._send_jsonrpc_request(status_request)
if "result" not in task_status or not task_status["result"]:
await asyncio.sleep(self.poll_interval)
continue
current_state = None
if (
"status" in task_status["result"]
and task_status["result"]["status"]
):
if "state" in task_status["result"]["status"]:
current_state = task_status["result"]["status"]["state"]
logger.info(f"Task status {task_id}: {current_state}")
# Check if the task was completed
if current_state in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]:
if current_state == TaskState.COMPLETED:
# Extract the response
if (
"status" in task_status["result"]
and "message" in task_status["result"]["status"]
and "parts"
in task_status["result"]["status"]["message"]
):
# Convert A2A parts to ADK
response_parts = []
for part in task_status["result"]["status"][
"message"
]["parts"]:
if "text" in part and part["text"]:
response_parts.append(
Part(text=part["text"])
)
elif "data" in part:
try:
json_text = json.dumps(
part["data"],
ensure_ascii=False,
indent=2,
)
response_parts.append(
Part(
text=f"```json\n{json_text}\n```"
)
)
except Exception:
response_parts.append(
Part(text="[Unserializable data]")
)
if response_parts:
yield Event(
author=self.name,
content=Content(
role="agent", parts=response_parts
),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Empty response from agent."
)
],
),
)
else: else:
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", role="agent",
parts=[ parts=[
Part(text="Empty response from agent.") Part(
text="Task completed, but no response message."
)
], ],
), ),
) )
else: elif current_state == TaskState.FAILED:
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", role="agent",
parts=[ parts=[
Part( Part(
text="Task completed, but no response message." text="The task failed during processing."
) )
], ],
), ),
) )
elif current_state == TaskState.FAILED: else: # CANCELED
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text="The task was canceled.")],
),
)
# Store in the session state for future reference
if ctx.session:
try:
ctx.session.state["a2a_task_result"] = task_status[
"result"
]
except Exception:
pass
break # Exit the loop of checking
except Exception as e:
logger.error(f"Error checking task: {str(e)}")
# If the timeout was exceeded, inform the user
if time.time() - start_time > self.max_wait_time:
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", role="agent",
parts=[ parts=[Part(text=f"Error checking task: {str(e)}")],
Part(text="The task failed during processing.")
],
),
)
else: # CANCELED
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text="The task was canceled.")],
), ),
) )
break
# Store in the session state for future reference # Wait before the next check
if ctx.session: await asyncio.sleep(self.poll_interval)
try:
ctx.session.state["a2a_task_result"] = task_status[
"result"
]
except Exception:
pass
break # Exit the loop of checking # If the timeout was exceeded
if time.time() - start_time >= self.timeout:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="The operation exceeded the timeout. Please try again later."
)
],
),
)
except Exception as e: except Exception as e:
print(f"Error checking task: {str(e)}") logger.error(f"Error sending task: {str(e)}")
# If the timeout was exceeded, inform the user
if time.time() - start_time > self.max_wait_time:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error checking task: {str(e)}")],
),
)
break
# Wait before the next check
await asyncio.sleep(self.poll_interval)
# If the timeout was exceeded
if time.time() - start_time >= self.timeout:
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", role="agent",
parts=[ parts=[
Part( Part(text=f"Error interacting with A2A agent: {str(e)}")
text="The operation exceeded the timeout. Please try again later."
)
], ],
), ),
) )
except Exception as e: except Exception as e:
# Handle any uncaught error # Handle any uncaught error
print(f"Error executing A2A agent: {str(e)}") logger.error(f"Error executing A2A agent: {str(e)}")
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(

View File

@ -55,6 +55,8 @@ from src.schemas.a2a_types import (
AgentCard, AgentCard,
AgentCapabilities, AgentCapabilities,
AgentSkill, AgentSkill,
AgentAuthentication,
AgentProvider,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -558,7 +560,6 @@ class A2AService:
for tool_name in mcp_tools: for tool_name in mcp_tools:
logger.info(f"Processing tool: {tool_name}") logger.info(f"Processing tool: {tool_name}")
# Buscar informações da ferramenta pelo ID
tool_info = None tool_info = None
if hasattr(mcp_server, "tools") and isinstance( if hasattr(mcp_server, "tools") and isinstance(
mcp_server.tools, list mcp_server.tools, list
@ -633,10 +634,18 @@ class A2AService:
name=agent.name, name=agent.name,
description=agent.description or "", description=agent.description or "",
url=f"{settings.API_URL}/api/v1/a2a/{agent_id}", url=f"{settings.API_URL}/api/v1/a2a/{agent_id}",
version="1.0.0", provider=AgentProvider(
organization=settings.ORGANIZATION_NAME,
url=settings.ORGANIZATION_URL,
),
version=f"{settings.API_VERSION}",
capabilities=capabilities,
authentication=AgentAuthentication(
schemes=["apiKey"],
credentials="x-api-key",
),
defaultInputModes=["text"], defaultInputModes=["text"],
defaultOutputModes=["text"], defaultOutputModes=["text"],
capabilities=capabilities,
skills=skills, skills=skills,
) )