refactor(agent_service): sanitize agent names and improve agent card fetching
This commit is contained in:
parent
0ca6b4f3e9
commit
3622260c11
@ -29,7 +29,7 @@
|
|||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal, Union, Dict, List, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
@ -103,6 +103,14 @@ def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]:
|
|||||||
logger.warning(f"Agent not found: {agent_id}")
|
logger.warning(f"Agent not found: {agent_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Sanitize agent name if it contains spaces or special characters
|
||||||
|
if agent.name and any(c for c in agent.name if not (c.isalnum() or c == "_")):
|
||||||
|
agent.name = "".join(
|
||||||
|
c if c.isalnum() or c == "_" else "_" for c in agent.name
|
||||||
|
)
|
||||||
|
# Update in database
|
||||||
|
db.commit()
|
||||||
|
|
||||||
return agent
|
return agent
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
logger.error(f"Error searching for agent {agent_id}: {str(e)}")
|
logger.error(f"Error searching for agent {agent_id}: {str(e)}")
|
||||||
@ -144,6 +152,17 @@ def get_agents_by_client(
|
|||||||
|
|
||||||
agents = query.offset(skip).limit(limit).all()
|
agents = query.offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
# Sanitize agent names if they contain spaces or special characters
|
||||||
|
for agent in agents:
|
||||||
|
if agent.name and any(
|
||||||
|
c for c in agent.name if not (c.isalnum() or c == "_")
|
||||||
|
):
|
||||||
|
agent.name = "".join(
|
||||||
|
c if c.isalnum() or c == "_" else "_" for c in agent.name
|
||||||
|
)
|
||||||
|
# Update in database
|
||||||
|
db.commit()
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
logger.error(f"Error searching for client agents {client_id}: {str(e)}")
|
logger.error(f"Error searching for client agents {client_id}: {str(e)}")
|
||||||
@ -176,7 +195,15 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
|||||||
agent_card = response.json()
|
agent_card = response.json()
|
||||||
|
|
||||||
# Update agent with information from agent card
|
# Update agent with information from agent card
|
||||||
agent.name = agent_card.get("name", "Unknown Agent")
|
# Only update name if not provided or empty, or sanitize it
|
||||||
|
if not agent.name or agent.name.strip() == "":
|
||||||
|
# Sanitize name: remove spaces and special characters
|
||||||
|
card_name = agent_card.get("name", "Unknown Agent")
|
||||||
|
sanitized_name = "".join(
|
||||||
|
c if c.isalnum() or c == "_" else "_" for c in card_name
|
||||||
|
)
|
||||||
|
agent.name = sanitized_name
|
||||||
|
|
||||||
agent.description = agent_card.get("description", "")
|
agent.description = agent_card.get("description", "")
|
||||||
|
|
||||||
if agent.config is None:
|
if agent.config is None:
|
||||||
@ -499,7 +526,14 @@ async def update_agent(
|
|||||||
)
|
)
|
||||||
agent_card = response.json()
|
agent_card = response.json()
|
||||||
|
|
||||||
agent_data["name"] = agent_card.get("name", "Unknown Agent")
|
# Only update name if the original update doesn't specify a name
|
||||||
|
if "name" not in agent_data or not agent_data["name"].strip():
|
||||||
|
# Sanitize name: remove spaces and special characters
|
||||||
|
card_name = agent_card.get("name", "Unknown Agent")
|
||||||
|
sanitized_name = "".join(
|
||||||
|
c if c.isalnum() or c == "_" else "_" for c in card_name
|
||||||
|
)
|
||||||
|
agent_data["name"] = sanitized_name
|
||||||
agent_data["description"] = agent_card.get("description", "")
|
agent_data["description"] = agent_card.get("description", "")
|
||||||
|
|
||||||
if "config" not in agent_data or agent_data["config"] is None:
|
if "config" not in agent_data or agent_data["config"] is None:
|
||||||
@ -537,7 +571,14 @@ async def update_agent(
|
|||||||
)
|
)
|
||||||
agent_card = response.json()
|
agent_card = response.json()
|
||||||
|
|
||||||
agent_data["name"] = agent_card.get("name", "Unknown Agent")
|
# Only update name if the original update doesn't specify a name
|
||||||
|
if "name" not in agent_data or not agent_data["name"].strip():
|
||||||
|
# Sanitize name: remove spaces and special characters
|
||||||
|
card_name = agent_card.get("name", "Unknown Agent")
|
||||||
|
sanitized_name = "".join(
|
||||||
|
c if c.isalnum() or c == "_" else "_" for c in card_name
|
||||||
|
)
|
||||||
|
agent_data["name"] = sanitized_name
|
||||||
agent_data["description"] = agent_card.get("description", "")
|
agent_data["description"] = agent_card.get("description", "")
|
||||||
|
|
||||||
if "config" not in agent_data or agent_data["config"] is None:
|
if "config" not in agent_data or agent_data["config"] is None:
|
||||||
|
@ -32,15 +32,21 @@ from google.adk.agents.invocation_context import InvocationContext
|
|||||||
from google.adk.events import Event
|
from google.adk.events import Event
|
||||||
from google.genai.types import Content, Part
|
from google.genai.types import Content, Part
|
||||||
|
|
||||||
from typing import AsyncGenerator, List
|
from typing import AsyncGenerator, List, Dict, Any, Optional
|
||||||
|
import json
|
||||||
from src.schemas.a2a_types import (
|
|
||||||
SendTaskRequest,
|
|
||||||
Message,
|
|
||||||
TextPart,
|
|
||||||
)
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from httpx_sse import connect_sse
|
||||||
|
|
||||||
|
from src.schemas.a2a_types import (
|
||||||
|
AgentCard,
|
||||||
|
Message,
|
||||||
|
TextPart,
|
||||||
|
TaskSendParams,
|
||||||
|
SendTaskRequest,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
TaskState,
|
||||||
|
)
|
||||||
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -54,7 +60,9 @@ class A2ACustomAgent(BaseAgent):
|
|||||||
|
|
||||||
# Field declarations for Pydantic
|
# Field declarations for Pydantic
|
||||||
agent_card_url: str
|
agent_card_url: str
|
||||||
|
agent_card: Optional[AgentCard]
|
||||||
timeout: int
|
timeout: int
|
||||||
|
base_url: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -73,16 +81,41 @@ class A2ACustomAgent(BaseAgent):
|
|||||||
timeout: Maximum execution time (seconds)
|
timeout: Maximum execution time (seconds)
|
||||||
sub_agents: List of sub-agents to be executed after the A2A agent
|
sub_agents: List of sub-agents to be executed after the A2A agent
|
||||||
"""
|
"""
|
||||||
|
# Create base_url from agent_card_url
|
||||||
|
base_url = agent_card_url
|
||||||
|
if "/.well-known/agent.json" in base_url:
|
||||||
|
base_url = base_url.split("/.well-known/agent.json")[0]
|
||||||
|
|
||||||
|
print(f"A2A agent initialized for URL: {agent_card_url}")
|
||||||
|
|
||||||
# Initialize base class
|
# Initialize base class
|
||||||
super().__init__(
|
super().__init__(
|
||||||
name=name,
|
name=name,
|
||||||
agent_card_url=agent_card_url,
|
agent_card_url=agent_card_url,
|
||||||
|
base_url=base_url, # Pass base_url here
|
||||||
|
agent_card=None,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
sub_agents=sub_agents,
|
sub_agents=sub_agents,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"A2A agent initialized for URL: {agent_card_url}")
|
async def fetch_agent_card(self) -> AgentCard:
|
||||||
|
"""Fetch the agent card from the A2A service."""
|
||||||
|
if self.agent_card:
|
||||||
|
return self.agent_card
|
||||||
|
|
||||||
|
card_url = f"{self.base_url}/.well-known/agent.json"
|
||||||
|
print(f"Fetching agent card from: {card_url}")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(card_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
try:
|
||||||
|
card_data = response.json()
|
||||||
|
self.agent_card = AgentCard(**card_data)
|
||||||
|
return self.agent_card
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Failed to parse agent card: {str(e)}")
|
||||||
|
|
||||||
async def _run_async_impl(
|
async def _run_async_impl(
|
||||||
self, ctx: InvocationContext
|
self, ctx: InvocationContext
|
||||||
@ -95,12 +128,18 @@ class A2ACustomAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare the base URL for the A2A
|
# 1. First, fetch the agent card if we haven't already
|
||||||
url = self.agent_card_url
|
try:
|
||||||
|
agent_card = await self.fetch_agent_card()
|
||||||
# Ensure that there is no /.well-known/agent.json in the url
|
print(f"Agent card fetched: {agent_card.name}")
|
||||||
if "/.well-known/agent.json" in url:
|
except Exception as e:
|
||||||
url = url.split("/.well-known/agent.json")[0]
|
error_msg = f"Failed to fetch agent card: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# 2. Extract the user's message from the context
|
# 2. Extract the user's message from the context
|
||||||
user_message = None
|
user_message = None
|
||||||
@ -120,7 +159,16 @@ class A2ACustomAgent(BaseAgent):
|
|||||||
elif "message" in ctx.session.state:
|
elif "message" in ctx.session.state:
|
||||||
user_message = ctx.session.state["message"]
|
user_message = ctx.session.state["message"]
|
||||||
|
|
||||||
# 3. Create and send the task to the A2A agent
|
if not user_message:
|
||||||
|
error_msg = "No user message found"
|
||||||
|
print(error_msg)
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. Create and format the task to send to the A2A agent
|
||||||
print(f"Sending task to A2A agent: {user_message[:100]}...")
|
print(f"Sending task to A2A agent: {user_message[:100]}...")
|
||||||
|
|
||||||
# Use the session ID as a stable identifier
|
# Use the session ID as a stable identifier
|
||||||
@ -131,46 +179,214 @@ class A2ACustomAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
task_id = str(uuid4())
|
task_id = str(uuid4())
|
||||||
|
|
||||||
try:
|
# Prepare the message for the A2A agent
|
||||||
|
formatted_message = Message(
|
||||||
formatted_message: Message = Message(
|
|
||||||
role="user",
|
role="user",
|
||||||
parts=[TextPart(type="text", text=user_message)],
|
parts=[TextPart(text=user_message)],
|
||||||
)
|
)
|
||||||
|
|
||||||
request: SendTaskRequest = SendTaskRequest(
|
# Prepare the task parameters
|
||||||
params={
|
task_params = TaskSendParams(
|
||||||
"message": formatted_message,
|
id=task_id,
|
||||||
"sessionId": session_id,
|
sessionId=session_id,
|
||||||
"id": task_id,
|
message=formatted_message,
|
||||||
}
|
acceptedOutputModes=["text"],
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Request send task: {request.model_dump()}")
|
# 4. Check if the agent supports streaming
|
||||||
|
supports_streaming = (
|
||||||
# REQUEST POST to url when jsonrpc is 2.0
|
agent_card.capabilities.streaming if agent_card.capabilities else False
|
||||||
task_result = await httpx.AsyncClient().post(
|
|
||||||
url, json=request.model_dump(), timeout=self.timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Task response: {task_result.json()}")
|
if supports_streaming:
|
||||||
print(f"Task sent successfully, ID: {task_id}")
|
print("Agent supports streaming, using streaming API")
|
||||||
|
# Process with streaming
|
||||||
|
try:
|
||||||
|
# Criar a requisição usando o método correto de tasks/sendSubscribe
|
||||||
|
request = SendTaskStreamingRequest(
|
||||||
|
method="tasks/sendSubscribe", params=task_params
|
||||||
|
)
|
||||||
|
|
||||||
agent_response_parts = task_result.json()["result"]["status"][
|
try:
|
||||||
"message"
|
async with httpx.AsyncClient() as client:
|
||||||
]["parts"]
|
response = await client.post(
|
||||||
|
self.base_url,
|
||||||
|
json=request.model_dump(),
|
||||||
|
headers={"Accept": "text/event-stream"},
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
parts = [Part(text=part["text"]) for part in agent_response_parts]
|
# Processar manualmente a resposta SSE
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data:"):
|
||||||
|
data = line[5:].strip()
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
result = json.loads(data)
|
||||||
|
print(f"Stream event received: {result}")
|
||||||
|
|
||||||
|
# Check if this is a status update with a message
|
||||||
|
if (
|
||||||
|
"result" in result
|
||||||
|
and "status" in result["result"]
|
||||||
|
and "message"
|
||||||
|
in result["result"]["status"]
|
||||||
|
and "parts"
|
||||||
|
in result["result"]["status"]["message"]
|
||||||
|
):
|
||||||
|
message_parts = result["result"][
|
||||||
|
"status"
|
||||||
|
]["message"]["parts"]
|
||||||
|
parts = [
|
||||||
|
Part(text=part["text"])
|
||||||
|
for part in message_parts
|
||||||
|
if part.get("type") == "text"
|
||||||
|
and "text" in part
|
||||||
|
]
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent", parts=parts
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if this is a final message
|
||||||
|
if (
|
||||||
|
"result" in result
|
||||||
|
and result.get("result", {}).get(
|
||||||
|
"final", False
|
||||||
|
)
|
||||||
|
and "status" in result["result"]
|
||||||
|
and result["result"]["status"].get(
|
||||||
|
"state"
|
||||||
|
)
|
||||||
|
in [
|
||||||
|
TaskState.COMPLETED,
|
||||||
|
TaskState.CANCELED,
|
||||||
|
TaskState.FAILED,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
"Received final message, stream complete"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Error parsing SSE data: {str(e)}")
|
||||||
|
except Exception as stream_error:
|
||||||
|
print(
|
||||||
|
f"Error in direct streaming: {str(stream_error)}, falling back to regular API"
|
||||||
|
)
|
||||||
|
# If streaming fails, fall back to regular API
|
||||||
|
# Criar a requisição usando o método correto de tasks/send
|
||||||
|
fallback_request = SendTaskRequest(
|
||||||
|
method="tasks/send", params=task_params
|
||||||
|
)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.base_url,
|
||||||
|
json=fallback_request.model_dump(),
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
print(f"Fallback response: {result}")
|
||||||
|
|
||||||
|
# Extract agent message parts
|
||||||
|
if (
|
||||||
|
"result" in result
|
||||||
|
and "status" in result["result"]
|
||||||
|
and "message" in result["result"]["status"]
|
||||||
|
and "parts" in result["result"]["status"]["message"]
|
||||||
|
):
|
||||||
|
message_parts = result["result"]["status"]["message"][
|
||||||
|
"parts"
|
||||||
|
]
|
||||||
|
parts = [
|
||||||
|
Part(text=part["text"])
|
||||||
|
for part in message_parts
|
||||||
|
if part.get("type") == "text" and "text" in part
|
||||||
|
]
|
||||||
|
|
||||||
|
if parts:
|
||||||
yield Event(
|
yield Event(
|
||||||
author=self.name,
|
author=self.name,
|
||||||
content=Content(role="agent", parts=parts),
|
content=Content(role="agent", parts=parts),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[
|
||||||
|
Part(
|
||||||
|
text="Received response without message parts"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error in streaming: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Agent does not support streaming, using regular API")
|
||||||
|
# Process with regular request
|
||||||
|
try:
|
||||||
|
# Criar a requisição usando o método correto de tasks/send
|
||||||
|
request = SendTaskRequest(method="tasks/send", params=task_params)
|
||||||
|
|
||||||
# Run sub-agents
|
async with httpx.AsyncClient() as client:
|
||||||
for sub_agent in self.sub_agents:
|
response = await client.post(
|
||||||
async for event in sub_agent.run_async(ctx):
|
self.base_url,
|
||||||
yield event
|
json=request.model_dump(),
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
print(f"Task response: {result}")
|
||||||
|
|
||||||
|
# Extract agent message parts
|
||||||
|
if (
|
||||||
|
"result" in result
|
||||||
|
and "status" in result["result"]
|
||||||
|
and "message" in result["result"]["status"]
|
||||||
|
and "parts" in result["result"]["status"]["message"]
|
||||||
|
):
|
||||||
|
message_parts = result["result"]["status"]["message"][
|
||||||
|
"parts"
|
||||||
|
]
|
||||||
|
parts = [
|
||||||
|
Part(text=part["text"])
|
||||||
|
for part in message_parts
|
||||||
|
if part.get("type") == "text" and "text" in part
|
||||||
|
]
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(role="agent", parts=parts),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[
|
||||||
|
Part(
|
||||||
|
text="Received response without message parts"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error sending request: {str(e)}"
|
error_msg = f"Error sending request: {str(e)}"
|
||||||
@ -182,15 +398,20 @@ class A2ACustomAgent(BaseAgent):
|
|||||||
author=self.name,
|
author=self.name,
|
||||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
# Run sub-agents
|
||||||
|
for sub_agent in self.sub_agents:
|
||||||
|
async for event in sub_agent.run_async(ctx):
|
||||||
|
yield event
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle any uncaught error
|
# Handle any uncaught error
|
||||||
print(f"Error executing A2A agent: {str(e)}")
|
error_msg = f"Error executing A2A agent: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
yield Event(
|
yield Event(
|
||||||
author=self.name,
|
author=self.name,
|
||||||
content=Content(
|
content=Content(
|
||||||
role="agent",
|
role="agent",
|
||||||
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")],
|
parts=[Part(text=error_msg)],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user