refactor(agent_service): sanitize agent names and improve agent card fetching
This commit is contained in:
parent
0ca6b4f3e9
commit
3622260c11
@ -29,7 +29,7 @@
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Any, Literal, Union, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
from typing_extensions import Self
|
||||
|
||||
|
@ -103,6 +103,14 @@ def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]:
|
||||
logger.warning(f"Agent not found: {agent_id}")
|
||||
return None
|
||||
|
||||
# Sanitize agent name if it contains spaces or special characters
|
||||
if agent.name and any(c for c in agent.name if not (c.isalnum() or c == "_")):
|
||||
agent.name = "".join(
|
||||
c if c.isalnum() or c == "_" else "_" for c in agent.name
|
||||
)
|
||||
# Update in database
|
||||
db.commit()
|
||||
|
||||
return agent
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error searching for agent {agent_id}: {str(e)}")
|
||||
@ -144,6 +152,17 @@ def get_agents_by_client(
|
||||
|
||||
agents = query.offset(skip).limit(limit).all()
|
||||
|
||||
# Sanitize agent names if they contain spaces or special characters
|
||||
for agent in agents:
|
||||
if agent.name and any(
|
||||
c for c in agent.name if not (c.isalnum() or c == "_")
|
||||
):
|
||||
agent.name = "".join(
|
||||
c if c.isalnum() or c == "_" else "_" for c in agent.name
|
||||
)
|
||||
# Update in database
|
||||
db.commit()
|
||||
|
||||
return agents
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error searching for client agents {client_id}: {str(e)}")
|
||||
@ -176,7 +195,15 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
||||
agent_card = response.json()
|
||||
|
||||
# Update agent with information from agent card
|
||||
agent.name = agent_card.get("name", "Unknown Agent")
|
||||
# Only update name if not provided or empty, or sanitize it
|
||||
if not agent.name or agent.name.strip() == "":
|
||||
# Sanitize name: remove spaces and special characters
|
||||
card_name = agent_card.get("name", "Unknown Agent")
|
||||
sanitized_name = "".join(
|
||||
c if c.isalnum() or c == "_" else "_" for c in card_name
|
||||
)
|
||||
agent.name = sanitized_name
|
||||
|
||||
agent.description = agent_card.get("description", "")
|
||||
|
||||
if agent.config is None:
|
||||
@ -499,7 +526,14 @@ async def update_agent(
|
||||
)
|
||||
agent_card = response.json()
|
||||
|
||||
agent_data["name"] = agent_card.get("name", "Unknown Agent")
|
||||
# Only update name if the original update doesn't specify a name
|
||||
if "name" not in agent_data or not agent_data["name"].strip():
|
||||
# Sanitize name: remove spaces and special characters
|
||||
card_name = agent_card.get("name", "Unknown Agent")
|
||||
sanitized_name = "".join(
|
||||
c if c.isalnum() or c == "_" else "_" for c in card_name
|
||||
)
|
||||
agent_data["name"] = sanitized_name
|
||||
agent_data["description"] = agent_card.get("description", "")
|
||||
|
||||
if "config" not in agent_data or agent_data["config"] is None:
|
||||
@ -537,7 +571,14 @@ async def update_agent(
|
||||
)
|
||||
agent_card = response.json()
|
||||
|
||||
agent_data["name"] = agent_card.get("name", "Unknown Agent")
|
||||
# Only update name if the original update doesn't specify a name
|
||||
if "name" not in agent_data or not agent_data["name"].strip():
|
||||
# Sanitize name: remove spaces and special characters
|
||||
card_name = agent_card.get("name", "Unknown Agent")
|
||||
sanitized_name = "".join(
|
||||
c if c.isalnum() or c == "_" else "_" for c in card_name
|
||||
)
|
||||
agent_data["name"] = sanitized_name
|
||||
agent_data["description"] = agent_card.get("description", "")
|
||||
|
||||
if "config" not in agent_data or agent_data["config"] is None:
|
||||
|
@ -32,15 +32,21 @@ from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.events import Event
|
||||
from google.genai.types import Content, Part
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from src.schemas.a2a_types import (
|
||||
SendTaskRequest,
|
||||
Message,
|
||||
TextPart,
|
||||
)
|
||||
from typing import AsyncGenerator, List, Dict, Any, Optional
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
from src.schemas.a2a_types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
TextPart,
|
||||
TaskSendParams,
|
||||
SendTaskRequest,
|
||||
SendTaskStreamingRequest,
|
||||
TaskState,
|
||||
)
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
@ -54,7 +60,9 @@ class A2ACustomAgent(BaseAgent):
|
||||
|
||||
# Field declarations for Pydantic
|
||||
agent_card_url: str
|
||||
agent_card: Optional[AgentCard]
|
||||
timeout: int
|
||||
base_url: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -73,16 +81,41 @@ class A2ACustomAgent(BaseAgent):
|
||||
timeout: Maximum execution time (seconds)
|
||||
sub_agents: List of sub-agents to be executed after the A2A agent
|
||||
"""
|
||||
# Create base_url from agent_card_url
|
||||
base_url = agent_card_url
|
||||
if "/.well-known/agent.json" in base_url:
|
||||
base_url = base_url.split("/.well-known/agent.json")[0]
|
||||
|
||||
print(f"A2A agent initialized for URL: {agent_card_url}")
|
||||
|
||||
# Initialize base class
|
||||
super().__init__(
|
||||
name=name,
|
||||
agent_card_url=agent_card_url,
|
||||
base_url=base_url, # Pass base_url here
|
||||
agent_card=None,
|
||||
timeout=timeout,
|
||||
sub_agents=sub_agents,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
print(f"A2A agent initialized for URL: {agent_card_url}")
|
||||
async def fetch_agent_card(self) -> AgentCard:
|
||||
"""Fetch the agent card from the A2A service."""
|
||||
if self.agent_card:
|
||||
return self.agent_card
|
||||
|
||||
card_url = f"{self.base_url}/.well-known/agent.json"
|
||||
print(f"Fetching agent card from: {card_url}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(card_url)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
card_data = response.json()
|
||||
self.agent_card = AgentCard(**card_data)
|
||||
return self.agent_card
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse agent card: {str(e)}")
|
||||
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
@ -95,12 +128,18 @@ class A2ACustomAgent(BaseAgent):
|
||||
"""
|
||||
|
||||
try:
|
||||
# Prepare the base URL for the A2A
|
||||
url = self.agent_card_url
|
||||
|
||||
# Ensure that there is no /.well-known/agent.json in the url
|
||||
if "/.well-known/agent.json" in url:
|
||||
url = url.split("/.well-known/agent.json")[0]
|
||||
# 1. First, fetch the agent card if we haven't already
|
||||
try:
|
||||
agent_card = await self.fetch_agent_card()
|
||||
print(f"Agent card fetched: {agent_card.name}")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to fetch agent card: {str(e)}"
|
||||
print(error_msg)
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||
)
|
||||
return
|
||||
|
||||
# 2. Extract the user's message from the context
|
||||
user_message = None
|
||||
@ -120,7 +159,16 @@ class A2ACustomAgent(BaseAgent):
|
||||
elif "message" in ctx.session.state:
|
||||
user_message = ctx.session.state["message"]
|
||||
|
||||
# 3. Create and send the task to the A2A agent
|
||||
if not user_message:
|
||||
error_msg = "No user message found"
|
||||
print(error_msg)
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||
)
|
||||
return
|
||||
|
||||
# 3. Create and format the task to send to the A2A agent
|
||||
print(f"Sending task to A2A agent: {user_message[:100]}...")
|
||||
|
||||
# Use the session ID as a stable identifier
|
||||
@ -131,66 +179,239 @@ class A2ACustomAgent(BaseAgent):
|
||||
)
|
||||
task_id = str(uuid4())
|
||||
|
||||
try:
|
||||
# Prepare the message for the A2A agent
|
||||
formatted_message = Message(
|
||||
role="user",
|
||||
parts=[TextPart(text=user_message)],
|
||||
)
|
||||
|
||||
formatted_message: Message = Message(
|
||||
role="user",
|
||||
parts=[TextPart(type="text", text=user_message)],
|
||||
)
|
||||
# Prepare the task parameters
|
||||
task_params = TaskSendParams(
|
||||
id=task_id,
|
||||
sessionId=session_id,
|
||||
message=formatted_message,
|
||||
acceptedOutputModes=["text"],
|
||||
)
|
||||
|
||||
request: SendTaskRequest = SendTaskRequest(
|
||||
params={
|
||||
"message": formatted_message,
|
||||
"sessionId": session_id,
|
||||
"id": task_id,
|
||||
}
|
||||
)
|
||||
# 4. Check if the agent supports streaming
|
||||
supports_streaming = (
|
||||
agent_card.capabilities.streaming if agent_card.capabilities else False
|
||||
)
|
||||
|
||||
print(f"Request send task: {request.model_dump()}")
|
||||
if supports_streaming:
|
||||
print("Agent supports streaming, using streaming API")
|
||||
# Process with streaming
|
||||
try:
|
||||
# Criar a requisição usando o método correto de tasks/sendSubscribe
|
||||
request = SendTaskStreamingRequest(
|
||||
method="tasks/sendSubscribe", params=task_params
|
||||
)
|
||||
|
||||
# REQUEST POST to url when jsonrpc is 2.0
|
||||
task_result = await httpx.AsyncClient().post(
|
||||
url, json=request.model_dump(), timeout=self.timeout
|
||||
)
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.base_url,
|
||||
json=request.model_dump(),
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
print(f"Task response: {task_result.json()}")
|
||||
print(f"Task sent successfully, ID: {task_id}")
|
||||
# Processar manualmente a resposta SSE
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[5:].strip()
|
||||
if data:
|
||||
try:
|
||||
result = json.loads(data)
|
||||
print(f"Stream event received: {result}")
|
||||
|
||||
agent_response_parts = task_result.json()["result"]["status"][
|
||||
"message"
|
||||
]["parts"]
|
||||
# Check if this is a status update with a message
|
||||
if (
|
||||
"result" in result
|
||||
and "status" in result["result"]
|
||||
and "message"
|
||||
in result["result"]["status"]
|
||||
and "parts"
|
||||
in result["result"]["status"]["message"]
|
||||
):
|
||||
message_parts = result["result"][
|
||||
"status"
|
||||
]["message"]["parts"]
|
||||
parts = [
|
||||
Part(text=part["text"])
|
||||
for part in message_parts
|
||||
if part.get("type") == "text"
|
||||
and "text" in part
|
||||
]
|
||||
|
||||
parts = [Part(text=part["text"]) for part in agent_response_parts]
|
||||
if parts:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent", parts=parts
|
||||
),
|
||||
)
|
||||
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=parts),
|
||||
)
|
||||
# Check if this is a final message
|
||||
if (
|
||||
"result" in result
|
||||
and result.get("result", {}).get(
|
||||
"final", False
|
||||
)
|
||||
and "status" in result["result"]
|
||||
and result["result"]["status"].get(
|
||||
"state"
|
||||
)
|
||||
in [
|
||||
TaskState.COMPLETED,
|
||||
TaskState.CANCELED,
|
||||
TaskState.FAILED,
|
||||
]
|
||||
):
|
||||
print(
|
||||
"Received final message, stream complete"
|
||||
)
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing SSE data: {str(e)}")
|
||||
except Exception as stream_error:
|
||||
print(
|
||||
f"Error in direct streaming: {str(stream_error)}, falling back to regular API"
|
||||
)
|
||||
# If streaming fails, fall back to regular API
|
||||
# Criar a requisição usando o método correto de tasks/send
|
||||
fallback_request = SendTaskRequest(
|
||||
method="tasks/send", params=task_params
|
||||
)
|
||||
|
||||
# Run sub-agents
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.base_url,
|
||||
json=fallback_request.model_dump(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error sending request: {str(e)}"
|
||||
print(error_msg)
|
||||
print(f"Error type: {type(e).__name__}")
|
||||
print(f"Error details: {str(e)}")
|
||||
result = response.json()
|
||||
print(f"Fallback response: {result}")
|
||||
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||
)
|
||||
return
|
||||
# Extract agent message parts
|
||||
if (
|
||||
"result" in result
|
||||
and "status" in result["result"]
|
||||
and "message" in result["result"]["status"]
|
||||
and "parts" in result["result"]["status"]["message"]
|
||||
):
|
||||
message_parts = result["result"]["status"]["message"][
|
||||
"parts"
|
||||
]
|
||||
parts = [
|
||||
Part(text=part["text"])
|
||||
for part in message_parts
|
||||
if part.get("type") == "text" and "text" in part
|
||||
]
|
||||
|
||||
if parts:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=parts),
|
||||
)
|
||||
else:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text="Received response without message parts"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in streaming: {str(e)}"
|
||||
print(error_msg)
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||
)
|
||||
else:
|
||||
print("Agent does not support streaming, using regular API")
|
||||
# Process with regular request
|
||||
try:
|
||||
# Criar a requisição usando o método correto de tasks/send
|
||||
request = SendTaskRequest(method="tasks/send", params=task_params)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.base_url,
|
||||
json=request.model_dump(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
print(f"Task response: {result}")
|
||||
|
||||
# Extract agent message parts
|
||||
if (
|
||||
"result" in result
|
||||
and "status" in result["result"]
|
||||
and "message" in result["result"]["status"]
|
||||
and "parts" in result["result"]["status"]["message"]
|
||||
):
|
||||
message_parts = result["result"]["status"]["message"][
|
||||
"parts"
|
||||
]
|
||||
parts = [
|
||||
Part(text=part["text"])
|
||||
for part in message_parts
|
||||
if part.get("type") == "text" and "text" in part
|
||||
]
|
||||
|
||||
if parts:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=parts),
|
||||
)
|
||||
else:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text="Received response without message parts"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error sending request: {str(e)}"
|
||||
print(error_msg)
|
||||
print(f"Error type: {type(e).__name__}")
|
||||
print(f"Error details: {str(e)}")
|
||||
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||
)
|
||||
|
||||
# Run sub-agents
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
# Handle any uncaught error
|
||||
print(f"Error executing A2A agent: {str(e)}")
|
||||
error_msg = f"Error executing A2A agent: {str(e)}"
|
||||
print(error_msg)
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")],
|
||||
parts=[Part(text=error_msg)],
|
||||
),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user