refactor(agent_service): sanitize agent names and improve agent card fetching

This commit is contained in:
Davidson Gomes 2025-05-14 15:10:48 -03:00
parent 0ca6b4f3e9
commit 3622260c11
3 changed files with 324 additions and 62 deletions

View File

@ -29,7 +29,7 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, Union, Dict, List, Optional
from uuid import uuid4
from typing_extensions import Self

View File

@ -103,6 +103,14 @@ def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]:
logger.warning(f"Agent not found: {agent_id}")
return None
# Sanitize agent name if it contains spaces or special characters
if agent.name and any(c for c in agent.name if not (c.isalnum() or c == "_")):
agent.name = "".join(
c if c.isalnum() or c == "_" else "_" for c in agent.name
)
# Update in database
db.commit()
return agent
except SQLAlchemyError as e:
logger.error(f"Error searching for agent {agent_id}: {str(e)}")
@ -144,6 +152,17 @@ def get_agents_by_client(
agents = query.offset(skip).limit(limit).all()
# Sanitize agent names if they contain spaces or special characters
for agent in agents:
if agent.name and any(
c for c in agent.name if not (c.isalnum() or c == "_")
):
agent.name = "".join(
c if c.isalnum() or c == "_" else "_" for c in agent.name
)
# Update in database
db.commit()
return agents
except SQLAlchemyError as e:
logger.error(f"Error searching for client agents {client_id}: {str(e)}")
@ -176,7 +195,15 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
agent_card = response.json()
# Update agent with information from agent card
agent.name = agent_card.get("name", "Unknown Agent")
# Only update name if not provided or empty, or sanitize it
if not agent.name or agent.name.strip() == "":
# Sanitize name: remove spaces and special characters
card_name = agent_card.get("name", "Unknown Agent")
sanitized_name = "".join(
c if c.isalnum() or c == "_" else "_" for c in card_name
)
agent.name = sanitized_name
agent.description = agent_card.get("description", "")
if agent.config is None:
@ -499,7 +526,14 @@ async def update_agent(
)
agent_card = response.json()
agent_data["name"] = agent_card.get("name", "Unknown Agent")
# Only update name if the original update doesn't specify a name
if "name" not in agent_data or not agent_data["name"].strip():
# Sanitize name: remove spaces and special characters
card_name = agent_card.get("name", "Unknown Agent")
sanitized_name = "".join(
c if c.isalnum() or c == "_" else "_" for c in card_name
)
agent_data["name"] = sanitized_name
agent_data["description"] = agent_card.get("description", "")
if "config" not in agent_data or agent_data["config"] is None:
@ -537,7 +571,14 @@ async def update_agent(
)
agent_card = response.json()
agent_data["name"] = agent_card.get("name", "Unknown Agent")
# Only update name if the original update doesn't specify a name
if "name" not in agent_data or not agent_data["name"].strip():
# Sanitize name: remove spaces and special characters
card_name = agent_card.get("name", "Unknown Agent")
sanitized_name = "".join(
c if c.isalnum() or c == "_" else "_" for c in card_name
)
agent_data["name"] = sanitized_name
agent_data["description"] = agent_card.get("description", "")
if "config" not in agent_data or agent_data["config"] is None:

View File

@ -32,15 +32,21 @@ from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.genai.types import Content, Part
from typing import AsyncGenerator, List
from src.schemas.a2a_types import (
SendTaskRequest,
Message,
TextPart,
)
from typing import AsyncGenerator, List, Dict, Any, Optional
import json
import httpx
from httpx_sse import connect_sse
from src.schemas.a2a_types import (
AgentCard,
Message,
TextPart,
TaskSendParams,
SendTaskRequest,
SendTaskStreamingRequest,
TaskState,
)
from uuid import uuid4
@ -54,7 +60,9 @@ class A2ACustomAgent(BaseAgent):
# Field declarations for Pydantic
agent_card_url: str
agent_card: Optional[AgentCard]
timeout: int
base_url: str
def __init__(
self,
@ -73,16 +81,41 @@ class A2ACustomAgent(BaseAgent):
timeout: Maximum execution time (seconds)
sub_agents: List of sub-agents to be executed after the A2A agent
"""
# Create base_url from agent_card_url
base_url = agent_card_url
if "/.well-known/agent.json" in base_url:
base_url = base_url.split("/.well-known/agent.json")[0]
print(f"A2A agent initialized for URL: {agent_card_url}")
# Initialize base class
super().__init__(
name=name,
agent_card_url=agent_card_url,
base_url=base_url, # Pass base_url here
agent_card=None,
timeout=timeout,
sub_agents=sub_agents,
**kwargs,
)
print(f"A2A agent initialized for URL: {agent_card_url}")
async def fetch_agent_card(self) -> AgentCard:
"""Fetch the agent card from the A2A service."""
if self.agent_card:
return self.agent_card
card_url = f"{self.base_url}/.well-known/agent.json"
print(f"Fetching agent card from: {card_url}")
async with httpx.AsyncClient() as client:
response = await client.get(card_url)
response.raise_for_status()
try:
card_data = response.json()
self.agent_card = AgentCard(**card_data)
return self.agent_card
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse agent card: {str(e)}")
async def _run_async_impl(
self, ctx: InvocationContext
@ -95,12 +128,18 @@ class A2ACustomAgent(BaseAgent):
"""
try:
# Prepare the base URL for the A2A
url = self.agent_card_url
# Ensure that there is no /.well-known/agent.json in the url
if "/.well-known/agent.json" in url:
url = url.split("/.well-known/agent.json")[0]
# 1. First, fetch the agent card if we haven't already
try:
agent_card = await self.fetch_agent_card()
print(f"Agent card fetched: {agent_card.name}")
except Exception as e:
error_msg = f"Failed to fetch agent card: {str(e)}"
print(error_msg)
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
return
# 2. Extract the user's message from the context
user_message = None
@ -120,7 +159,16 @@ class A2ACustomAgent(BaseAgent):
elif "message" in ctx.session.state:
user_message = ctx.session.state["message"]
# 3. Create and send the task to the A2A agent
if not user_message:
error_msg = "No user message found"
print(error_msg)
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
return
# 3. Create and format the task to send to the A2A agent
print(f"Sending task to A2A agent: {user_message[:100]}...")
# Use the session ID as a stable identifier
@ -131,66 +179,239 @@ class A2ACustomAgent(BaseAgent):
)
task_id = str(uuid4())
try:
# Prepare the message for the A2A agent
formatted_message = Message(
role="user",
parts=[TextPart(text=user_message)],
)
formatted_message: Message = Message(
role="user",
parts=[TextPart(type="text", text=user_message)],
)
# Prepare the task parameters
task_params = TaskSendParams(
id=task_id,
sessionId=session_id,
message=formatted_message,
acceptedOutputModes=["text"],
)
request: SendTaskRequest = SendTaskRequest(
params={
"message": formatted_message,
"sessionId": session_id,
"id": task_id,
}
)
# 4. Check if the agent supports streaming
supports_streaming = (
agent_card.capabilities.streaming if agent_card.capabilities else False
)
print(f"Request send task: {request.model_dump()}")
if supports_streaming:
print("Agent supports streaming, using streaming API")
# Process with streaming
try:
# Criar a requisição usando o método correto de tasks/sendSubscribe
request = SendTaskStreamingRequest(
method="tasks/sendSubscribe", params=task_params
)
# REQUEST POST to url when jsonrpc is 2.0
task_result = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=self.timeout
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.base_url,
json=request.model_dump(),
headers={"Accept": "text/event-stream"},
timeout=self.timeout,
)
response.raise_for_status()
print(f"Task response: {task_result.json()}")
print(f"Task sent successfully, ID: {task_id}")
# Processar manualmente a resposta SSE
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[5:].strip()
if data:
try:
result = json.loads(data)
print(f"Stream event received: {result}")
agent_response_parts = task_result.json()["result"]["status"][
"message"
]["parts"]
# Check if this is a status update with a message
if (
"result" in result
and "status" in result["result"]
and "message"
in result["result"]["status"]
and "parts"
in result["result"]["status"]["message"]
):
message_parts = result["result"][
"status"
]["message"]["parts"]
parts = [
Part(text=part["text"])
for part in message_parts
if part.get("type") == "text"
and "text" in part
]
parts = [Part(text=part["text"]) for part in agent_response_parts]
if parts:
yield Event(
author=self.name,
content=Content(
role="agent", parts=parts
),
)
yield Event(
author=self.name,
content=Content(role="agent", parts=parts),
)
# Check if this is a final message
if (
"result" in result
and result.get("result", {}).get(
"final", False
)
and "status" in result["result"]
and result["result"]["status"].get(
"state"
)
in [
TaskState.COMPLETED,
TaskState.CANCELED,
TaskState.FAILED,
]
):
print(
"Received final message, stream complete"
)
break
except json.JSONDecodeError as e:
print(f"Error parsing SSE data: {str(e)}")
except Exception as stream_error:
print(
f"Error in direct streaming: {str(stream_error)}, falling back to regular API"
)
# If streaming fails, fall back to regular API
# Criar a requisição usando o método correto de tasks/send
fallback_request = SendTaskRequest(
method="tasks/send", params=task_params
)
# Run sub-agents
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
async with httpx.AsyncClient() as client:
response = await client.post(
self.base_url,
json=fallback_request.model_dump(),
timeout=self.timeout,
)
response.raise_for_status()
except Exception as e:
error_msg = f"Error sending request: {str(e)}"
print(error_msg)
print(f"Error type: {type(e).__name__}")
print(f"Error details: {str(e)}")
result = response.json()
print(f"Fallback response: {result}")
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
return
# Extract agent message parts
if (
"result" in result
and "status" in result["result"]
and "message" in result["result"]["status"]
and "parts" in result["result"]["status"]["message"]
):
message_parts = result["result"]["status"]["message"][
"parts"
]
parts = [
Part(text=part["text"])
for part in message_parts
if part.get("type") == "text" and "text" in part
]
if parts:
yield Event(
author=self.name,
content=Content(role="agent", parts=parts),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Received response without message parts"
)
],
),
)
except Exception as e:
error_msg = f"Error in streaming: {str(e)}"
print(error_msg)
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
else:
print("Agent does not support streaming, using regular API")
# Process with regular request
try:
# Criar a requisição usando o método correto de tasks/send
request = SendTaskRequest(method="tasks/send", params=task_params)
async with httpx.AsyncClient() as client:
response = await client.post(
self.base_url,
json=request.model_dump(),
timeout=self.timeout,
)
response.raise_for_status()
result = response.json()
print(f"Task response: {result}")
# Extract agent message parts
if (
"result" in result
and "status" in result["result"]
and "message" in result["result"]["status"]
and "parts" in result["result"]["status"]["message"]
):
message_parts = result["result"]["status"]["message"][
"parts"
]
parts = [
Part(text=part["text"])
for part in message_parts
if part.get("type") == "text" and "text" in part
]
if parts:
yield Event(
author=self.name,
content=Content(role="agent", parts=parts),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Received response without message parts"
)
],
),
)
except Exception as e:
error_msg = f"Error sending request: {str(e)}"
print(error_msg)
print(f"Error type: {type(e).__name__}")
print(f"Error details: {str(e)}")
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=error_msg)]),
)
# Run sub-agents
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
except Exception as e:
# Handle any uncaught error
print(f"Error executing A2A agent: {str(e)}")
error_msg = f"Error executing A2A agent: {str(e)}"
print(error_msg)
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")],
parts=[Part(text=error_msg)],
),
)