From 3622260c119ec222d989d9b4ea0ca466f5e9daa4 Mon Sep 17 00:00:00 2001 From: Davidson Gomes Date: Wed, 14 May 2025 15:10:48 -0300 Subject: [PATCH] refactor(agent_service): sanitize agent names and improve agent card fetching --- src/schemas/a2a_types.py | 2 +- src/services/agent_service.py | 47 +++- src/services/custom_agents/a2a_agent.py | 337 ++++++++++++++++++++---- 3 files changed, 324 insertions(+), 62 deletions(-) diff --git a/src/schemas/a2a_types.py b/src/schemas/a2a_types.py index c63b761f..e5c2c9bb 100644 --- a/src/schemas/a2a_types.py +++ b/src/schemas/a2a_types.py @@ -29,7 +29,7 @@ from datetime import datetime from enum import Enum -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, Union, Dict, List, Optional from uuid import uuid4 from typing_extensions import Self diff --git a/src/services/agent_service.py b/src/services/agent_service.py index bb3fd57c..b699daf8 100644 --- a/src/services/agent_service.py +++ b/src/services/agent_service.py @@ -103,6 +103,14 @@ def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]: logger.warning(f"Agent not found: {agent_id}") return None + # Sanitize agent name if it contains spaces or special characters + if agent.name and any(c for c in agent.name if not (c.isalnum() or c == "_")): + agent.name = "".join( + c if c.isalnum() or c == "_" else "_" for c in agent.name + ) + # Update in database + db.commit() + return agent except SQLAlchemyError as e: logger.error(f"Error searching for agent {agent_id}: {str(e)}") @@ -144,6 +152,17 @@ def get_agents_by_client( agents = query.offset(skip).limit(limit).all() + # Sanitize agent names if they contain spaces or special characters + for agent in agents: + if agent.name and any( + c for c in agent.name if not (c.isalnum() or c == "_") + ): + agent.name = "".join( + c if c.isalnum() or c == "_" else "_" for c in agent.name + ) + # Update in database + db.commit() + return agents except SQLAlchemyError as e: logger.error(f"Error searching for client agents {client_id}: {str(e)}") @@ -176,7 +195,15 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent: agent_card = response.json() # Update agent with information from agent card - agent.name = agent_card.get("name", "Unknown Agent") + # Only update name if not provided or empty, or sanitize it + if not agent.name or agent.name.strip() == "": + # Sanitize name: remove spaces and special characters + card_name = agent_card.get("name", "Unknown Agent") + sanitized_name = "".join( + c if c.isalnum() or c == "_" else "_" for c in card_name + ) + agent.name = sanitized_name + agent.description = agent_card.get("description", "") if agent.config is None: @@ -499,7 +526,14 @@ async def update_agent( ) agent_card = response.json() - agent_data["name"] = agent_card.get("name", "Unknown Agent") + # Only update name if the original update doesn't specify a name + if "name" not in agent_data or not agent_data["name"].strip(): + # Sanitize name: remove spaces and special characters + card_name = agent_card.get("name", "Unknown Agent") + sanitized_name = "".join( + c if c.isalnum() or c == "_" else "_" for c in card_name + ) + agent_data["name"] = sanitized_name agent_data["description"] = agent_card.get("description", "") if "config" not in agent_data or agent_data["config"] is None: @@ -537,7 +571,14 @@ async def update_agent( ) agent_card = response.json() - agent_data["name"] = agent_card.get("name", "Unknown Agent") + # Only update name if the original update doesn't specify a name + if "name" not in agent_data or not agent_data["name"].strip(): + # Sanitize name: remove spaces and special characters + card_name = agent_card.get("name", "Unknown Agent") + sanitized_name = "".join( + c if c.isalnum() or c == "_" else "_" for c in card_name + ) + agent_data["name"] = sanitized_name agent_data["description"] = agent_card.get("description", "") if "config" not in agent_data or agent_data["config"] is None: diff --git a/src/services/custom_agents/a2a_agent.py b/src/services/custom_agents/a2a_agent.py index 020179be..ed2588a7 100644 --- a/src/services/custom_agents/a2a_agent.py +++ b/src/services/custom_agents/a2a_agent.py @@ -32,15 +32,21 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event from google.genai.types import Content, Part -from typing import AsyncGenerator, List - -from src.schemas.a2a_types import ( - SendTaskRequest, - Message, - TextPart, -) +from typing import AsyncGenerator, List, Dict, Any, Optional +import json import httpx +from httpx_sse import connect_sse + +from src.schemas.a2a_types import ( + AgentCard, + Message, + TextPart, + TaskSendParams, + SendTaskRequest, + SendTaskStreamingRequest, + TaskState, +) from uuid import uuid4 @@ -54,7 +60,9 @@ class A2ACustomAgent(BaseAgent): # Field declarations for Pydantic agent_card_url: str + agent_card: Optional[AgentCard] timeout: int + base_url: str def __init__( self, @@ -73,16 +81,41 @@ class A2ACustomAgent(BaseAgent): timeout: Maximum execution time (seconds) sub_agents: List of sub-agents to be executed after the A2A agent """ + # Create base_url from agent_card_url + base_url = agent_card_url + if "/.well-known/agent.json" in base_url: + base_url = base_url.split("/.well-known/agent.json")[0] + + print(f"A2A agent initialized for URL: {agent_card_url}") + # Initialize base class super().__init__( name=name, agent_card_url=agent_card_url, + base_url=base_url, # Pass base_url here + agent_card=None, timeout=timeout, sub_agents=sub_agents, **kwargs, ) - print(f"A2A agent initialized for URL: {agent_card_url}") + async def fetch_agent_card(self) -> AgentCard: + """Fetch the agent card from the A2A service.""" + if self.agent_card: + return self.agent_card + + card_url = f"{self.base_url}/.well-known/agent.json" + print(f"Fetching agent card from: {card_url}") + + async with httpx.AsyncClient() as client: + response = await client.get(card_url) + response.raise_for_status() + try: + card_data = response.json() + self.agent_card = AgentCard(**card_data) + return self.agent_card + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse agent card: {str(e)}") async def _run_async_impl( self, ctx: InvocationContext @@ -95,12 +128,18 @@ class A2ACustomAgent(BaseAgent): """ try: - # Prepare the base URL for the A2A - url = self.agent_card_url - - # Ensure that there is no /.well-known/agent.json in the url - if "/.well-known/agent.json" in url: - url = url.split("/.well-known/agent.json")[0] + # 1. First, fetch the agent card if we haven't already + try: + agent_card = await self.fetch_agent_card() + print(f"Agent card fetched: {agent_card.name}") + except Exception as e: + error_msg = f"Failed to fetch agent card: {str(e)}" + print(error_msg) + yield Event( + author=self.name, + content=Content(role="agent", parts=[Part(text=error_msg)]), + ) + return # 2. Extract the user's message from the context user_message = None @@ -120,7 +159,16 @@ class A2ACustomAgent(BaseAgent): elif "message" in ctx.session.state: user_message = ctx.session.state["message"] - # 3. Create and send the task to the A2A agent + if not user_message: + error_msg = "No user message found" + print(error_msg) + yield Event( + author=self.name, + content=Content(role="agent", parts=[Part(text=error_msg)]), + ) + return + + # 3. Create and format the task to send to the A2A agent print(f"Sending task to A2A agent: {user_message[:100]}...") # Use the session ID as a stable identifier @@ -131,66 +179,239 @@ class A2ACustomAgent(BaseAgent): ) task_id = str(uuid4()) - try: + # Prepare the message for the A2A agent + formatted_message = Message( + role="user", + parts=[TextPart(text=user_message)], + ) - formatted_message: Message = Message( - role="user", - parts=[TextPart(type="text", text=user_message)], - ) + # Prepare the task parameters + task_params = TaskSendParams( + id=task_id, + sessionId=session_id, + message=formatted_message, + acceptedOutputModes=["text"], + ) - request: SendTaskRequest = SendTaskRequest( - params={ - "message": formatted_message, - "sessionId": session_id, - "id": task_id, - } - ) + # 4. Check if the agent supports streaming + supports_streaming = ( + agent_card.capabilities.streaming if agent_card.capabilities else False + ) - print(f"Request send task: {request.model_dump()}") + if supports_streaming: + print("Agent supports streaming, using streaming API") + # Process with streaming + try: + # Criar a requisição usando o método correto de tasks/sendSubscribe + request = SendTaskStreamingRequest( + method="tasks/sendSubscribe", params=task_params + ) - # REQUEST POST to url when jsonrpc is 2.0 - task_result = await httpx.AsyncClient().post( - url, json=request.model_dump(), timeout=self.timeout - ) + try: + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url, + json=request.model_dump(), + headers={"Accept": "text/event-stream"}, + timeout=self.timeout, + ) + response.raise_for_status() - print(f"Task response: {task_result.json()}") - print(f"Task sent successfully, ID: {task_id}") + # Processar manualmente a resposta SSE + async for line in response.aiter_lines(): + if line.startswith("data:"): + data = line[5:].strip() + if data: + try: + result = json.loads(data) + print(f"Stream event received: {result}") - agent_response_parts = task_result.json()["result"]["status"][ - "message" - ]["parts"] + # Check if this is a status update with a message + if ( + "result" in result + and "status" in result["result"] + and "message" + in result["result"]["status"] + and "parts" + in result["result"]["status"]["message"] + ): + message_parts = result["result"][ + "status" + ]["message"]["parts"] + parts = [ + Part(text=part["text"]) + for part in message_parts + if part.get("type") == "text" + and "text" in part + ] - parts = [Part(text=part["text"]) for part in agent_response_parts] + if parts: + yield Event( + author=self.name, + content=Content( + role="agent", parts=parts + ), + ) - yield Event( - author=self.name, - content=Content(role="agent", parts=parts), - ) + # Check if this is a final message + if ( + "result" in result + and result.get("result", {}).get( + "final", False + ) + and "status" in result["result"] + and result["result"]["status"].get( + "state" + ) + in [ + TaskState.COMPLETED, + TaskState.CANCELED, + TaskState.FAILED, + ] + ): + print( + "Received final message, stream complete" + ) + break + except json.JSONDecodeError as e: + print(f"Error parsing SSE data: {str(e)}") + except Exception as stream_error: + print( + f"Error in direct streaming: {str(stream_error)}, falling back to regular API" + ) + # If streaming fails, fall back to regular API + # Criar a requisição usando o método correto de tasks/send + fallback_request = SendTaskRequest( + method="tasks/send", params=task_params + ) - # Run sub-agents - for sub_agent in self.sub_agents: - async for event in sub_agent.run_async(ctx): - yield event + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url, + json=fallback_request.model_dump(), + timeout=self.timeout, + ) + response.raise_for_status() - except Exception as e: - error_msg = f"Error sending request: {str(e)}" - print(error_msg) - print(f"Error type: {type(e).__name__}") - print(f"Error details: {str(e)}") + result = response.json() + print(f"Fallback response: {result}") - yield Event( - author=self.name, - content=Content(role="agent", parts=[Part(text=error_msg)]), - ) - return + # Extract agent message parts + if ( + "result" in result + and "status" in result["result"] + and "message" in result["result"]["status"] + and "parts" in result["result"]["status"]["message"] + ): + message_parts = result["result"]["status"]["message"][ + "parts" + ] + parts = [ + Part(text=part["text"]) + for part in message_parts + if part.get("type") == "text" and "text" in part + ] + + if parts: + yield Event( + author=self.name, + content=Content(role="agent", parts=parts), + ) + else: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text="Received response without message parts" + ) + ], + ), + ) + except Exception as e: + error_msg = f"Error in streaming: {str(e)}" + print(error_msg) + yield Event( + author=self.name, + content=Content(role="agent", parts=[Part(text=error_msg)]), + ) + else: + print("Agent does not support streaming, using regular API") + # Process with regular request + try: + # Criar a requisição usando o método correto de tasks/send + request = SendTaskRequest(method="tasks/send", params=task_params) + + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url, + json=request.model_dump(), + timeout=self.timeout, + ) + response.raise_for_status() + + result = response.json() + print(f"Task response: {result}") + + # Extract agent message parts + if ( + "result" in result + and "status" in result["result"] + and "message" in result["result"]["status"] + and "parts" in result["result"]["status"]["message"] + ): + message_parts = result["result"]["status"]["message"][ + "parts" + ] + parts = [ + Part(text=part["text"]) + for part in message_parts + if part.get("type") == "text" and "text" in part + ] + + if parts: + yield Event( + author=self.name, + content=Content(role="agent", parts=parts), + ) + else: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text="Received response without message parts" + ) + ], + ), + ) + + except Exception as e: + error_msg = f"Error sending request: {str(e)}" + print(error_msg) + print(f"Error type: {type(e).__name__}") + print(f"Error details: {str(e)}") + + yield Event( + author=self.name, + content=Content(role="agent", parts=[Part(text=error_msg)]), + ) + + # Run sub-agents + for sub_agent in self.sub_agents: + async for event in sub_agent.run_async(ctx): + yield event except Exception as e: # Handle any uncaught error - print(f"Error executing A2A agent: {str(e)}") + error_msg = f"Error executing A2A agent: {str(e)}" + print(error_msg) yield Event( author=self.name, content=Content( role="agent", - parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")], + parts=[Part(text=error_msg)], ), )