refactor(agent_service): sanitize agent names and improve agent card fetching

This commit is contained in:
Davidson Gomes 2025-05-14 15:10:48 -03:00
parent 0ca6b4f3e9
commit 3622260c11
3 changed files with 324 additions and 62 deletions

View File

@ -29,7 +29,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal, Union, Dict, List, Optional
from uuid import uuid4 from uuid import uuid4
from typing_extensions import Self from typing_extensions import Self

View File

@ -103,6 +103,14 @@ def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]:
logger.warning(f"Agent not found: {agent_id}") logger.warning(f"Agent not found: {agent_id}")
return None return None
# Sanitize agent name if it contains spaces or special characters
if agent.name and any(c for c in agent.name if not (c.isalnum() or c == "_")):
agent.name = "".join(
c if c.isalnum() or c == "_" else "_" for c in agent.name
)
# Update in database
db.commit()
return agent return agent
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Error searching for agent {agent_id}: {str(e)}") logger.error(f"Error searching for agent {agent_id}: {str(e)}")
@ -144,6 +152,17 @@ def get_agents_by_client(
agents = query.offset(skip).limit(limit).all() agents = query.offset(skip).limit(limit).all()
# Sanitize agent names if they contain spaces or special characters
for agent in agents:
if agent.name and any(
c for c in agent.name if not (c.isalnum() or c == "_")
):
agent.name = "".join(
c if c.isalnum() or c == "_" else "_" for c in agent.name
)
# Update in database
db.commit()
return agents return agents
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Error searching for client agents {client_id}: {str(e)}") logger.error(f"Error searching for client agents {client_id}: {str(e)}")
@ -176,7 +195,15 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
agent_card = response.json() agent_card = response.json()
# Update agent with information from agent card # Update agent with information from agent card
agent.name = agent_card.get("name", "Unknown Agent") # Only update name if not provided or empty, or sanitize it
if not agent.name or agent.name.strip() == "":
# Sanitize name: remove spaces and special characters
card_name = agent_card.get("name", "Unknown Agent")
sanitized_name = "".join(
c if c.isalnum() or c == "_" else "_" for c in card_name
)
agent.name = sanitized_name
agent.description = agent_card.get("description", "") agent.description = agent_card.get("description", "")
if agent.config is None: if agent.config is None:
@ -499,7 +526,14 @@ async def update_agent(
) )
agent_card = response.json() agent_card = response.json()
agent_data["name"] = agent_card.get("name", "Unknown Agent") # Only update name if the original update doesn't specify a name
if "name" not in agent_data or not agent_data["name"].strip():
# Sanitize name: remove spaces and special characters
card_name = agent_card.get("name", "Unknown Agent")
sanitized_name = "".join(
c if c.isalnum() or c == "_" else "_" for c in card_name
)
agent_data["name"] = sanitized_name
agent_data["description"] = agent_card.get("description", "") agent_data["description"] = agent_card.get("description", "")
if "config" not in agent_data or agent_data["config"] is None: if "config" not in agent_data or agent_data["config"] is None:
@ -537,7 +571,14 @@ async def update_agent(
) )
agent_card = response.json() agent_card = response.json()
agent_data["name"] = agent_card.get("name", "Unknown Agent") # Only update name if the original update doesn't specify a name
if "name" not in agent_data or not agent_data["name"].strip():
# Sanitize name: remove spaces and special characters
card_name = agent_card.get("name", "Unknown Agent")
sanitized_name = "".join(
c if c.isalnum() or c == "_" else "_" for c in card_name
)
agent_data["name"] = sanitized_name
agent_data["description"] = agent_card.get("description", "") agent_data["description"] = agent_card.get("description", "")
if "config" not in agent_data or agent_data["config"] is None: if "config" not in agent_data or agent_data["config"] is None:

View File

@ -32,15 +32,21 @@ from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event from google.adk.events import Event
from google.genai.types import Content, Part from google.genai.types import Content, Part
from typing import AsyncGenerator, List from typing import AsyncGenerator, List, Dict, Any, Optional
import json
from src.schemas.a2a_types import (
SendTaskRequest,
Message,
TextPart,
)
import httpx import httpx
from httpx_sse import connect_sse
from src.schemas.a2a_types import (
AgentCard,
Message,
TextPart,
TaskSendParams,
SendTaskRequest,
SendTaskStreamingRequest,
TaskState,
)
from uuid import uuid4 from uuid import uuid4
@ -54,7 +60,9 @@ class A2ACustomAgent(BaseAgent):
# Field declarations for Pydantic # Field declarations for Pydantic
agent_card_url: str agent_card_url: str
agent_card: Optional[AgentCard]
timeout: int timeout: int
base_url: str
def __init__( def __init__(
self, self,
@ -73,16 +81,41 @@ class A2ACustomAgent(BaseAgent):
timeout: Maximum execution time (seconds) timeout: Maximum execution time (seconds)
sub_agents: List of sub-agents to be executed after the A2A agent sub_agents: List of sub-agents to be executed after the A2A agent
""" """
# Create base_url from agent_card_url
base_url = agent_card_url
if "/.well-known/agent.json" in base_url:
base_url = base_url.split("/.well-known/agent.json")[0]
print(f"A2A agent initialized for URL: {agent_card_url}")
# Initialize base class # Initialize base class
super().__init__( super().__init__(
name=name, name=name,
agent_card_url=agent_card_url, agent_card_url=agent_card_url,
base_url=base_url, # Pass base_url here
agent_card=None,
timeout=timeout, timeout=timeout,
sub_agents=sub_agents, sub_agents=sub_agents,
**kwargs, **kwargs,
) )
print(f"A2A agent initialized for URL: {agent_card_url}") async def fetch_agent_card(self) -> AgentCard:
"""Fetch the agent card from the A2A service."""
if self.agent_card:
return self.agent_card
card_url = f"{self.base_url}/.well-known/agent.json"
print(f"Fetching agent card from: {card_url}")
async with httpx.AsyncClient() as client:
response = await client.get(card_url)
response.raise_for_status()
try:
card_data = response.json()
self.agent_card = AgentCard(**card_data)
return self.agent_card
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse agent card: {str(e)}")
async def _run_async_impl( async def _run_async_impl(
self, ctx: InvocationContext self, ctx: InvocationContext
@ -95,12 +128,18 @@ class A2ACustomAgent(BaseAgent):
""" """
try: try:
# Prepare the base URL for the A2A # 1. First, fetch the agent card if we haven't already
url = self.agent_card_url try:
agent_card = await self.fetch_agent_card()
# Ensure that there is no /.well-known/agent.json in the url print(f"Agent card fetched: {agent_card.name}")
if "/.well-known/agent.json" in url: except Exception as e:
url = url.split("/.well-known/agent.json")[0] error_msg = f"Failed to fetch agent card: {str(e)}"
print(error_msg)
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
return
# 2. Extract the user's message from the context # 2. Extract the user's message from the context
user_message = None user_message = None
@ -120,7 +159,16 @@ class A2ACustomAgent(BaseAgent):
elif "message" in ctx.session.state: elif "message" in ctx.session.state:
user_message = ctx.session.state["message"] user_message = ctx.session.state["message"]
# 3. Create and send the task to the A2A agent if not user_message:
error_msg = "No user message found"
print(error_msg)
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
return
# 3. Create and format the task to send to the A2A agent
print(f"Sending task to A2A agent: {user_message[:100]}...") print(f"Sending task to A2A agent: {user_message[:100]}...")
# Use the session ID as a stable identifier # Use the session ID as a stable identifier
@ -131,46 +179,214 @@ class A2ACustomAgent(BaseAgent):
) )
task_id = str(uuid4()) task_id = str(uuid4())
try: # Prepare the message for the A2A agent
formatted_message = Message(
formatted_message: Message = Message(
role="user", role="user",
parts=[TextPart(type="text", text=user_message)], parts=[TextPart(text=user_message)],
) )
request: SendTaskRequest = SendTaskRequest( # Prepare the task parameters
params={ task_params = TaskSendParams(
"message": formatted_message, id=task_id,
"sessionId": session_id, sessionId=session_id,
"id": task_id, message=formatted_message,
} acceptedOutputModes=["text"],
) )
print(f"Request send task: {request.model_dump()}") # 4. Check if the agent supports streaming
supports_streaming = (
# REQUEST POST to url when jsonrpc is 2.0 agent_card.capabilities.streaming if agent_card.capabilities else False
task_result = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=self.timeout
) )
print(f"Task response: {task_result.json()}") if supports_streaming:
print(f"Task sent successfully, ID: {task_id}") print("Agent supports streaming, using streaming API")
# Process with streaming
try:
# Criar a requisição usando o método correto de tasks/sendSubscribe
request = SendTaskStreamingRequest(
method="tasks/sendSubscribe", params=task_params
)
agent_response_parts = task_result.json()["result"]["status"][ try:
"message" async with httpx.AsyncClient() as client:
]["parts"] response = await client.post(
self.base_url,
json=request.model_dump(),
headers={"Accept": "text/event-stream"},
timeout=self.timeout,
)
response.raise_for_status()
parts = [Part(text=part["text"]) for part in agent_response_parts] # Processar manualmente a resposta SSE
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[5:].strip()
if data:
try:
result = json.loads(data)
print(f"Stream event received: {result}")
# Check if this is a status update with a message
if (
"result" in result
and "status" in result["result"]
and "message"
in result["result"]["status"]
and "parts"
in result["result"]["status"]["message"]
):
message_parts = result["result"][
"status"
]["message"]["parts"]
parts = [
Part(text=part["text"])
for part in message_parts
if part.get("type") == "text"
and "text" in part
]
if parts:
yield Event(
author=self.name,
content=Content(
role="agent", parts=parts
),
)
# Check if this is a final message
if (
"result" in result
and result.get("result", {}).get(
"final", False
)
and "status" in result["result"]
and result["result"]["status"].get(
"state"
)
in [
TaskState.COMPLETED,
TaskState.CANCELED,
TaskState.FAILED,
]
):
print(
"Received final message, stream complete"
)
break
except json.JSONDecodeError as e:
print(f"Error parsing SSE data: {str(e)}")
except Exception as stream_error:
print(
f"Error in direct streaming: {str(stream_error)}, falling back to regular API"
)
# If streaming fails, fall back to regular API
# Criar a requisição usando o método correto de tasks/send
fallback_request = SendTaskRequest(
method="tasks/send", params=task_params
)
async with httpx.AsyncClient() as client:
response = await client.post(
self.base_url,
json=fallback_request.model_dump(),
timeout=self.timeout,
)
response.raise_for_status()
result = response.json()
print(f"Fallback response: {result}")
# Extract agent message parts
if (
"result" in result
and "status" in result["result"]
and "message" in result["result"]["status"]
and "parts" in result["result"]["status"]["message"]
):
message_parts = result["result"]["status"]["message"][
"parts"
]
parts = [
Part(text=part["text"])
for part in message_parts
if part.get("type") == "text" and "text" in part
]
if parts:
yield Event( yield Event(
author=self.name, author=self.name,
content=Content(role="agent", parts=parts), content=Content(role="agent", parts=parts),
) )
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Received response without message parts"
)
],
),
)
except Exception as e:
error_msg = f"Error in streaming: {str(e)}"
print(error_msg)
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
else:
print("Agent does not support streaming, using regular API")
# Process with regular request
try:
# Criar a requisição usando o método correto de tasks/send
request = SendTaskRequest(method="tasks/send", params=task_params)
# Run sub-agents async with httpx.AsyncClient() as client:
for sub_agent in self.sub_agents: response = await client.post(
async for event in sub_agent.run_async(ctx): self.base_url,
yield event json=request.model_dump(),
timeout=self.timeout,
)
response.raise_for_status()
result = response.json()
print(f"Task response: {result}")
# Extract agent message parts
if (
"result" in result
and "status" in result["result"]
and "message" in result["result"]["status"]
and "parts" in result["result"]["status"]["message"]
):
message_parts = result["result"]["status"]["message"][
"parts"
]
parts = [
Part(text=part["text"])
for part in message_parts
if part.get("type") == "text" and "text" in part
]
if parts:
yield Event(
author=self.name,
content=Content(role="agent", parts=parts),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Received response without message parts"
)
],
),
)
except Exception as e: except Exception as e:
error_msg = f"Error sending request: {str(e)}" error_msg = f"Error sending request: {str(e)}"
@ -182,15 +398,20 @@ class A2ACustomAgent(BaseAgent):
author=self.name, author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]), content=Content(role="agent", parts=[Part(text=error_msg)]),
) )
return
# Run sub-agents
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
except Exception as e: except Exception as e:
# Handle any uncaught error # Handle any uncaught error
print(f"Error executing A2A agent: {str(e)}") error_msg = f"Error executing A2A agent: {str(e)}"
print(error_msg)
yield Event( yield Event(
author=self.name, author=self.name,
content=Content( content=Content(
role="agent", role="agent",
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")], parts=[Part(text=error_msg)],
), ),
) )