From 16e9747cce80dceb63db5b0f5b1bf03f83151c86 Mon Sep 17 00:00:00 2001 From: Davidson Gomes Date: Mon, 5 May 2025 21:46:24 -0300 Subject: [PATCH] refactor(a2a): enhance A2A agent functionality with improved error handling, logging, and support for streaming requests --- src/services/a2a_agent.py | 501 +++++++++++++++++++++---------- src/services/a2a_task_manager.py | 15 +- 2 files changed, 362 insertions(+), 154 deletions(-) diff --git a/src/services/a2a_agent.py b/src/services/a2a_agent.py index 135d73bc..5d0c7886 100644 --- a/src/services/a2a_agent.py +++ b/src/services/a2a_agent.py @@ -2,23 +2,28 @@ from google.adk.agents import BaseAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event from google.genai.types import Content, Part -from typing import AsyncGenerator +from typing import AsyncGenerator, Dict, Any, Optional import json import asyncio import time +import logging from src.schemas.a2a_types import ( GetTaskRequest, SendTaskRequest, + SendTaskStreamingRequest, Message, TextPart, TaskState, + JSONRPCRequest, ) import httpx from uuid import uuid4 +logger = logging.getLogger(__name__) + class A2ACustomAgent(BaseAgent): """ @@ -29,17 +34,22 @@ class A2ACustomAgent(BaseAgent): # Field declarations for Pydantic agent_card_url: str + api_key: Optional[str] poll_interval: float max_wait_time: int timeout: int + streaming: bool + base_url: Optional[str] = None def __init__( self, name: str, agent_card_url: str, + api_key: Optional[str] = None, poll_interval: float = 1.0, max_wait_time: int = 60, timeout: int = 300, + streaming: bool = True, **kwargs, ): """ @@ -48,21 +58,83 @@ class A2ACustomAgent(BaseAgent): Args: name: Agent name agent_card_url: A2A agent card URL + api_key: API key for authentication poll_interval: Status check interval (seconds) max_wait_time: Maximum wait time for a task (seconds) timeout: Maximum execution time (seconds) + streaming: Whether to use streaming mode """ - # Initialize base class + # Get base URL by removing agent.json if present + derived_base_url = agent_card_url + if "/.well-known/agent.json" in derived_base_url: + derived_base_url = derived_base_url.split("/.well-known/agent.json")[0] + + # Initialize base class with all fields including base_url super().__init__( name=name, agent_card_url=agent_card_url, + api_key=api_key, poll_interval=poll_interval, max_wait_time=max_wait_time, timeout=timeout, + streaming=streaming, + base_url=derived_base_url, **kwargs, ) - print(f"A2A agent initialized for URL: {agent_card_url}") + logger.info(f"A2A agent initialized for URL: {agent_card_url}") + + # Default headers + self.headers = {"Content-Type": "application/json"} + if api_key: + self.headers["x-api-key"] = api_key + + async def _send_jsonrpc_request(self, request: JSONRPCRequest) -> Dict[str, Any]: + """ + Send a JSON-RPC request to the A2A endpoint. + + Args: + request: The JSON-RPC request object + + Returns: + Dict containing the response + """ + try: + async with httpx.AsyncClient() as client: + logger.debug( + f"Sending request to {self.base_url}: {request.model_dump()}" + ) + response = await client.post( + self.base_url, + json=request.model_dump(), + headers=self.headers, + timeout=30, + ) + + response.raise_for_status() + response_data = response.json() + logger.debug(f"Received response: {response_data}") + + # Check for JSON-RPC errors + if "error" in response_data and response_data["error"]: + error_msg = response_data["error"].get("message", "Unknown error") + error_code = response_data["error"].get("code", -1) + logger.error(f"JSON-RPC error {error_code}: {error_msg}") + raise ValueError(f"A2A server error: {error_msg}") + + return response_data + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error: {e.response.status_code} - {e}") + raise ValueError(f"HTTP error {e.response.status_code}: {str(e)}") + except httpx.RequestError as e: + logger.error(f"Request error: {e}") + raise ValueError(f"Request error: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"JSON decode error: {e}") + raise ValueError(f"Invalid JSON response: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise ValueError(f"Error communicating with A2A server: {str(e)}") async def _run_async_impl( self, ctx: InvocationContext @@ -77,13 +149,6 @@ class A2ACustomAgent(BaseAgent): yield Event(author=self.name) try: - # Prepare the base URL for the A2A - url = self.agent_card_url - - # Ensure that there is no /.well-known/agent.json in the url - if "/.well-known/agent.json" in url: - url = url.split("/.well-known/agent.json")[0] - # 2. Extract the user's message from the context user_message = None @@ -92,7 +157,7 @@ class A2ACustomAgent(BaseAgent): for event in reversed(ctx.session.events): if event.author == "user" and event.content and event.content.parts: user_message = event.content.parts[0].text - print("Message found in session events") + logger.info("Message found in session events") break # Check in the session state if the message was not found in the events @@ -102,8 +167,21 @@ class A2ACustomAgent(BaseAgent): elif "message" in ctx.session.state: user_message = ctx.session.state["message"] + if not user_message: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part(text="Error: No message found to send to A2A agent") + ], + ), + ) + yield Event(author=self.name) # Final event + return + # 3. Create and send the task to the A2A agent - print(f"Sending task to A2A agent: {user_message[:100]}...") + logger.info(f"Sending task to A2A agent: {user_message[:100]}...") # Use the session ID as a stable identifier session_id = ( @@ -112,210 +190,331 @@ class A2ACustomAgent(BaseAgent): else str(uuid4()) ) task_id = str(uuid4()) + request_id = str(uuid4()) try: - - formatted_message: Message = Message( + # Format message according to A2A protocol + formatted_message = Message( role="user", parts=[TextPart(type="text", text=user_message)], ) - request: SendTaskRequest = SendTaskRequest( - params={ - "message": formatted_message, - "sessionId": session_id, - "id": task_id, - } - ) - - print(f"Request send task: {request.model_dump()}") - - # REQUEST POST to url when jsonrpc is 2.0 - task_result = await httpx.AsyncClient().post( - url, json=request.model_dump(), timeout=30 - ) - - print(f"Task response: {task_result.json()}") - print(f"Task sent successfully, ID: {task_id}") + # Prepare standard params for A2A + task_params = { + "id": task_id, + "sessionId": session_id, + "message": formatted_message, + "acceptedOutputModes": ["text"], + } + # Emit processing message yield Event( author=self.name, content=Content( - role="agent", parts=[Part(text="Processing request...")] + role="agent", parts=[Part(text="Processing your request...")] ), ) - except Exception as e: - error_msg = f"Error sending request: {str(e)}" - print(error_msg) - print(f"Error type: {type(e).__name__}") - print(f"Error details: {str(e)}") - yield Event( - author=self.name, - content=Content(role="agent", parts=[Part(text=str(e))]), - ) - yield Event(author=self.name) # Final event - return - - start_time = time.time() - - while time.time() - start_time < self.timeout: - try: - # Check current status - request: GetTaskRequest = GetTaskRequest(params={"id": task_id}) - - task_status_response = await httpx.AsyncClient().post( - url, json=request.model_dump(), timeout=30 + if self.streaming: + # Use streaming mode + request = SendTaskStreamingRequest( + id=request_id, params=task_params ) - print(f"Response get task: {task_status_response.json()}") + # Handle streaming response + accumulated_response = "" - task_status = task_status_response.json() + try: + async with httpx.AsyncClient(timeout=None) as client: + async with client.stream( + "POST", + self.base_url, + json=request.model_dump(), + headers=self.headers, + ) as response: + response.raise_for_status() - if "result" not in task_status or not task_status["result"]: - await asyncio.sleep(self.poll_interval) - continue + # Process SSE events + async for line in response.aiter_lines(): + if not line or line.strip() == "": + continue - current_state = None - if ( - "status" in task_status["result"] - and task_status["result"]["status"] - ): - if "state" in task_status["result"]["status"]: - current_state = task_status["result"]["status"]["state"] - - print(f"Task status {task_id}: {current_state}") - - # Check if the task was completed - if current_state in [ - TaskState.COMPLETED, - TaskState.FAILED, - TaskState.CANCELED, - ]: - if current_state == TaskState.COMPLETED: - # Extract the response - if ( - "status" in task_status["result"] - and "message" in task_status["result"]["status"] - and "parts" - in task_status["result"]["status"]["message"] - ): - - # Convert A2A parts to ADK - response_parts = [] - for part in task_status["result"]["status"]["message"][ - "parts" - ]: - if "text" in part and part["text"]: - response_parts.append(Part(text=part["text"])) - elif "data" in part: + if line.startswith("data:"): + data = line[5:].strip() try: - json_text = json.dumps( - part["data"], - ensure_ascii=False, - indent=2, - ) - response_parts.append( - Part(text=f"```json\n{json_text}\n```") - ) - except Exception: - response_parts.append( - Part(text="[Unserializable data]") + event_data = json.loads(data) + logger.debug(f"SSE event: {event_data}") + + # Process artifacts + if ( + "result" in event_data + and "artifact" in event_data["result"] + ): + artifact = event_data["result"][ + "artifact" + ] + if artifact and "parts" in artifact: + for part in artifact["parts"]: + if ( + "text" in part + and part["text"] + ): + accumulated_response += ( + part["text"] + ) + + # Emit incremental update + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text=accumulated_response + ) + ], + ), + ) + + # Check if task is complete + if ( + "result" in event_data + and "status" in event_data["result"] + and "final" in event_data["result"] + and event_data["result"]["final"] + ): + logger.info( + "Task completed in streaming mode" + ) + break + + except json.JSONDecodeError as e: + logger.error( + f"Error parsing SSE event: {e}" ) - if response_parts: + except Exception as e: + logger.error(f"Error in streaming mode: {e}") + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[Part(text=f"Error in streaming mode: {str(e)}")], + ), + ) + + # If we have a response, we're done + if accumulated_response: + yield Event(author=self.name) # Final event + return + + else: + # Use non-streaming mode + request = SendTaskRequest(id=request_id, params=task_params) + + # Make the request + response_data = await self._send_jsonrpc_request(request) + + # Process the response + if "result" in response_data and response_data["result"]: + task_result = response_data["result"] + + # Extract message from response + if ( + "status" in task_result + and "message" in task_result["status"] + and "parts" in task_result["status"]["message"] + ): + + parts = task_result["status"]["message"]["parts"] + for part in parts: + if "text" in part and part["text"]: yield Event( author=self.name, content=Content( - role="agent", parts=response_parts + role="agent", + parts=[Part(text=part["text"])], ), ) + yield Event(author=self.name) # Final event + return + + # If we reach here, we need to poll for results + start_time = time.time() + + while time.time() - start_time < self.timeout: + try: + # Check current status + status_request = GetTaskRequest( + id=request_id, params={"id": task_id, "historyLength": 10} + ) + + task_status = await self._send_jsonrpc_request(status_request) + + if "result" not in task_status or not task_status["result"]: + await asyncio.sleep(self.poll_interval) + continue + + current_state = None + if ( + "status" in task_status["result"] + and task_status["result"]["status"] + ): + if "state" in task_status["result"]["status"]: + current_state = task_status["result"]["status"]["state"] + + logger.info(f"Task status {task_id}: {current_state}") + + # Check if the task was completed + if current_state in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + ]: + if current_state == TaskState.COMPLETED: + # Extract the response + if ( + "status" in task_status["result"] + and "message" in task_status["result"]["status"] + and "parts" + in task_status["result"]["status"]["message"] + ): + + # Convert A2A parts to ADK + response_parts = [] + for part in task_status["result"]["status"][ + "message" + ]["parts"]: + if "text" in part and part["text"]: + response_parts.append( + Part(text=part["text"]) + ) + elif "data" in part: + try: + json_text = json.dumps( + part["data"], + ensure_ascii=False, + indent=2, + ) + response_parts.append( + Part( + text=f"```json\n{json_text}\n```" + ) + ) + except Exception: + response_parts.append( + Part(text="[Unserializable data]") + ) + + if response_parts: + yield Event( + author=self.name, + content=Content( + role="agent", parts=response_parts + ), + ) + else: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text="Empty response from agent." + ) + ], + ), + ) else: yield Event( author=self.name, content=Content( role="agent", parts=[ - Part(text="Empty response from agent.") + Part( + text="Task completed, but no response message." + ) ], ), ) - else: + elif current_state == TaskState.FAILED: yield Event( author=self.name, content=Content( role="agent", parts=[ Part( - text="Task completed, but no response message." + text="The task failed during processing." ) ], ), ) - elif current_state == TaskState.FAILED: + else: # CANCELED + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[Part(text="The task was canceled.")], + ), + ) + + # Store in the session state for future reference + if ctx.session: + try: + ctx.session.state["a2a_task_result"] = task_status[ + "result" + ] + except Exception: + pass + + break # Exit the loop of checking + + except Exception as e: + logger.error(f"Error checking task: {str(e)}") + + # If the timeout was exceeded, inform the user + if time.time() - start_time > self.max_wait_time: yield Event( author=self.name, content=Content( role="agent", - parts=[ - Part(text="The task failed during processing.") - ], - ), - ) - else: # CANCELED - yield Event( - author=self.name, - content=Content( - role="agent", - parts=[Part(text="The task was canceled.")], + parts=[Part(text=f"Error checking task: {str(e)}")], ), ) + break - # Store in the session state for future reference - if ctx.session: - try: - ctx.session.state["a2a_task_result"] = task_status[ - "result" - ] - except Exception: - pass + # Wait before the next check + await asyncio.sleep(self.poll_interval) - break # Exit the loop of checking + # If the timeout was exceeded + if time.time() - start_time >= self.timeout: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text="The operation exceeded the timeout. Please try again later." + ) + ], + ), + ) - except Exception as e: - print(f"Error checking task: {str(e)}") - - # If the timeout was exceeded, inform the user - if time.time() - start_time > self.max_wait_time: - yield Event( - author=self.name, - content=Content( - role="agent", - parts=[Part(text=f"Error checking task: {str(e)}")], - ), - ) - break - - # Wait before the next check - await asyncio.sleep(self.poll_interval) - - # If the timeout was exceeded - if time.time() - start_time >= self.timeout: + except Exception as e: + logger.error(f"Error sending task: {str(e)}") yield Event( author=self.name, content=Content( role="agent", parts=[ - Part( - text="The operation exceeded the timeout. Please try again later." - ) + Part(text=f"Error interacting with A2A agent: {str(e)}") ], ), ) except Exception as e: # Handle any uncaught error - print(f"Error executing A2A agent: {str(e)}") + logger.error(f"Error executing A2A agent: {str(e)}") yield Event( author=self.name, content=Content( diff --git a/src/services/a2a_task_manager.py b/src/services/a2a_task_manager.py index 08280270..af882966 100644 --- a/src/services/a2a_task_manager.py +++ b/src/services/a2a_task_manager.py @@ -55,6 +55,8 @@ from src.schemas.a2a_types import ( AgentCard, AgentCapabilities, AgentSkill, + AgentAuthentication, + AgentProvider, ) logger = logging.getLogger(__name__) @@ -558,7 +560,6 @@ class A2AService: for tool_name in mcp_tools: logger.info(f"Processing tool: {tool_name}") - # Buscar informações da ferramenta pelo ID tool_info = None if hasattr(mcp_server, "tools") and isinstance( mcp_server.tools, list @@ -633,10 +634,18 @@ class A2AService: name=agent.name, description=agent.description or "", url=f"{settings.API_URL}/api/v1/a2a/{agent_id}", - version="1.0.0", + provider=AgentProvider( + organization=settings.ORGANIZATION_NAME, + url=settings.ORGANIZATION_URL, + ), + version=f"{settings.API_VERSION}", + capabilities=capabilities, + authentication=AgentAuthentication( + schemes=["apiKey"], + credentials="x-api-key", + ), defaultInputModes=["text"], defaultOutputModes=["text"], - capabilities=capabilities, skills=skills, )