refactor(a2a): simplify A2A agent initialization and enhance sub-agent support
This commit is contained in:
parent
16e9747cce
commit
64e483533d
@ -2,28 +2,19 @@ from google.adk.agents import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.events import Event
|
||||
from google.genai.types import Content, Part
|
||||
from typing import AsyncGenerator, Dict, Any, Optional
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from src.schemas.a2a_types import (
|
||||
GetTaskRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskStreamingRequest,
|
||||
Message,
|
||||
TextPart,
|
||||
TaskState,
|
||||
JSONRPCRequest,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2ACustomAgent(BaseAgent):
|
||||
"""
|
||||
@ -34,22 +25,14 @@ class A2ACustomAgent(BaseAgent):
|
||||
|
||||
# Field declarations for Pydantic
|
||||
agent_card_url: str
|
||||
api_key: Optional[str]
|
||||
poll_interval: float
|
||||
max_wait_time: int
|
||||
timeout: int
|
||||
streaming: bool
|
||||
base_url: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
agent_card_url: str,
|
||||
api_key: Optional[str] = None,
|
||||
poll_interval: float = 1.0,
|
||||
max_wait_time: int = 60,
|
||||
timeout: int = 300,
|
||||
streaming: bool = True,
|
||||
sub_agents: List[BaseAgent] = [],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -58,83 +41,19 @@ class A2ACustomAgent(BaseAgent):
|
||||
Args:
|
||||
name: Agent name
|
||||
agent_card_url: A2A agent card URL
|
||||
api_key: API key for authentication
|
||||
poll_interval: Status check interval (seconds)
|
||||
max_wait_time: Maximum wait time for a task (seconds)
|
||||
timeout: Maximum execution time (seconds)
|
||||
streaming: Whether to use streaming mode
|
||||
sub_agents: List of sub-agents to be executed after the A2A agent
|
||||
"""
|
||||
# Get base URL by removing agent.json if present
|
||||
derived_base_url = agent_card_url
|
||||
if "/.well-known/agent.json" in derived_base_url:
|
||||
derived_base_url = derived_base_url.split("/.well-known/agent.json")[0]
|
||||
|
||||
# Initialize base class with all fields including base_url
|
||||
# Initialize base class
|
||||
super().__init__(
|
||||
name=name,
|
||||
agent_card_url=agent_card_url,
|
||||
api_key=api_key,
|
||||
poll_interval=poll_interval,
|
||||
max_wait_time=max_wait_time,
|
||||
timeout=timeout,
|
||||
streaming=streaming,
|
||||
base_url=derived_base_url,
|
||||
sub_agents=sub_agents,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"A2A agent initialized for URL: {agent_card_url}")
|
||||
|
||||
# Default headers
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
self.headers["x-api-key"] = api_key
|
||||
|
||||
async def _send_jsonrpc_request(self, request: JSONRPCRequest) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a JSON-RPC request to the A2A endpoint.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request object
|
||||
|
||||
Returns:
|
||||
Dict containing the response
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
logger.debug(
|
||||
f"Sending request to {self.base_url}: {request.model_dump()}"
|
||||
)
|
||||
response = await client.post(
|
||||
self.base_url,
|
||||
json=request.model_dump(),
|
||||
headers=self.headers,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
logger.debug(f"Received response: {response_data}")
|
||||
|
||||
# Check for JSON-RPC errors
|
||||
if "error" in response_data and response_data["error"]:
|
||||
error_msg = response_data["error"].get("message", "Unknown error")
|
||||
error_code = response_data["error"].get("code", -1)
|
||||
logger.error(f"JSON-RPC error {error_code}: {error_msg}")
|
||||
raise ValueError(f"A2A server error: {error_msg}")
|
||||
|
||||
return response_data
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error: {e.response.status_code} - {e}")
|
||||
raise ValueError(f"HTTP error {e.response.status_code}: {str(e)}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
raise ValueError(f"Request error: {str(e)}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON decode error: {e}")
|
||||
raise ValueError(f"Invalid JSON response: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
raise ValueError(f"Error communicating with A2A server: {str(e)}")
|
||||
print(f"A2A agent initialized for URL: {agent_card_url}")
|
||||
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
@ -145,10 +64,15 @@ class A2ACustomAgent(BaseAgent):
|
||||
This method follows the pattern of implementing custom agents,
|
||||
sending the user's message to the A2A service and monitoring the response.
|
||||
"""
|
||||
# 1. Yield the initial event
|
||||
yield Event(author=self.name)
|
||||
|
||||
try:
|
||||
# Prepare the base URL for the A2A
|
||||
url = self.agent_card_url
|
||||
|
||||
# Ensure that there is no /.well-known/agent.json in the url
|
||||
if "/.well-known/agent.json" in url:
|
||||
url = url.split("/.well-known/agent.json")[0]
|
||||
|
||||
# 2. Extract the user's message from the context
|
||||
user_message = None
|
||||
|
||||
@ -157,7 +81,7 @@ class A2ACustomAgent(BaseAgent):
|
||||
for event in reversed(ctx.session.events):
|
||||
if event.author == "user" and event.content and event.content.parts:
|
||||
user_message = event.content.parts[0].text
|
||||
logger.info("Message found in session events")
|
||||
print("Message found in session events")
|
||||
break
|
||||
|
||||
# Check in the session state if the message was not found in the events
|
||||
@ -167,21 +91,8 @@ class A2ACustomAgent(BaseAgent):
|
||||
elif "message" in ctx.session.state:
|
||||
user_message = ctx.session.state["message"]
|
||||
|
||||
if not user_message:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(text="Error: No message found to send to A2A agent")
|
||||
],
|
||||
),
|
||||
)
|
||||
yield Event(author=self.name) # Final event
|
||||
return
|
||||
|
||||
# 3. Create and send the task to the A2A agent
|
||||
logger.info(f"Sending task to A2A agent: {user_message[:100]}...")
|
||||
print(f"Sending task to A2A agent: {user_message[:100]}...")
|
||||
|
||||
# Use the session ID as a stable identifier
|
||||
session_id = (
|
||||
@ -190,331 +101,63 @@ class A2ACustomAgent(BaseAgent):
|
||||
else str(uuid4())
|
||||
)
|
||||
task_id = str(uuid4())
|
||||
request_id = str(uuid4())
|
||||
|
||||
try:
|
||||
# Format message according to A2A protocol
|
||||
formatted_message = Message(
|
||||
|
||||
formatted_message: Message = Message(
|
||||
role="user",
|
||||
parts=[TextPart(type="text", text=user_message)],
|
||||
)
|
||||
|
||||
# Prepare standard params for A2A
|
||||
task_params = {
|
||||
"id": task_id,
|
||||
"sessionId": session_id,
|
||||
"message": formatted_message,
|
||||
"acceptedOutputModes": ["text"],
|
||||
}
|
||||
|
||||
# Emit processing message
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent", parts=[Part(text="Processing your request...")]
|
||||
),
|
||||
request: SendTaskRequest = SendTaskRequest(
|
||||
params={
|
||||
"message": formatted_message,
|
||||
"sessionId": session_id,
|
||||
"id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
if self.streaming:
|
||||
# Use streaming mode
|
||||
request = SendTaskStreamingRequest(
|
||||
id=request_id, params=task_params
|
||||
)
|
||||
print(f"Request send task: {request.model_dump()}")
|
||||
|
||||
# Handle streaming response
|
||||
accumulated_response = ""
|
||||
# REQUEST POST to url when jsonrpc is 2.0
|
||||
task_result = await httpx.AsyncClient().post(
|
||||
url, json=request.model_dump(), timeout=self.timeout
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=None) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.base_url,
|
||||
json=request.model_dump(),
|
||||
headers=self.headers,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
print(f"Task response: {task_result.json()}")
|
||||
print(f"Task sent successfully, ID: {task_id}")
|
||||
|
||||
# Process SSE events
|
||||
async for line in response.aiter_lines():
|
||||
if not line or line.strip() == "":
|
||||
continue
|
||||
agent_response_parts = task_result.json()["result"]["status"][
|
||||
"message"
|
||||
]["parts"]
|
||||
|
||||
if line.startswith("data:"):
|
||||
data = line[5:].strip()
|
||||
try:
|
||||
event_data = json.loads(data)
|
||||
logger.debug(f"SSE event: {event_data}")
|
||||
parts = [Part(text=part["text"]) for part in agent_response_parts]
|
||||
|
||||
# Process artifacts
|
||||
if (
|
||||
"result" in event_data
|
||||
and "artifact" in event_data["result"]
|
||||
):
|
||||
artifact = event_data["result"][
|
||||
"artifact"
|
||||
]
|
||||
if artifact and "parts" in artifact:
|
||||
for part in artifact["parts"]:
|
||||
if (
|
||||
"text" in part
|
||||
and part["text"]
|
||||
):
|
||||
accumulated_response += (
|
||||
part["text"]
|
||||
)
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(role="agent", parts=parts),
|
||||
)
|
||||
|
||||
# Emit incremental update
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text=accumulated_response
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Check if task is complete
|
||||
if (
|
||||
"result" in event_data
|
||||
and "status" in event_data["result"]
|
||||
and "final" in event_data["result"]
|
||||
and event_data["result"]["final"]
|
||||
):
|
||||
logger.info(
|
||||
"Task completed in streaming mode"
|
||||
)
|
||||
break
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"Error parsing SSE event: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming mode: {e}")
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[Part(text=f"Error in streaming mode: {str(e)}")],
|
||||
),
|
||||
)
|
||||
|
||||
# If we have a response, we're done
|
||||
if accumulated_response:
|
||||
yield Event(author=self.name) # Final event
|
||||
return
|
||||
|
||||
else:
|
||||
# Use non-streaming mode
|
||||
request = SendTaskRequest(id=request_id, params=task_params)
|
||||
|
||||
# Make the request
|
||||
response_data = await self._send_jsonrpc_request(request)
|
||||
|
||||
# Process the response
|
||||
if "result" in response_data and response_data["result"]:
|
||||
task_result = response_data["result"]
|
||||
|
||||
# Extract message from response
|
||||
if (
|
||||
"status" in task_result
|
||||
and "message" in task_result["status"]
|
||||
and "parts" in task_result["status"]["message"]
|
||||
):
|
||||
|
||||
parts = task_result["status"]["message"]["parts"]
|
||||
for part in parts:
|
||||
if "text" in part and part["text"]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[Part(text=part["text"])],
|
||||
),
|
||||
)
|
||||
yield Event(author=self.name) # Final event
|
||||
return
|
||||
|
||||
# If we reach here, we need to poll for results
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < self.timeout:
|
||||
try:
|
||||
# Check current status
|
||||
status_request = GetTaskRequest(
|
||||
id=request_id, params={"id": task_id, "historyLength": 10}
|
||||
)
|
||||
|
||||
task_status = await self._send_jsonrpc_request(status_request)
|
||||
|
||||
if "result" not in task_status or not task_status["result"]:
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
continue
|
||||
|
||||
current_state = None
|
||||
if (
|
||||
"status" in task_status["result"]
|
||||
and task_status["result"]["status"]
|
||||
):
|
||||
if "state" in task_status["result"]["status"]:
|
||||
current_state = task_status["result"]["status"]["state"]
|
||||
|
||||
logger.info(f"Task status {task_id}: {current_state}")
|
||||
|
||||
# Check if the task was completed
|
||||
if current_state in [
|
||||
TaskState.COMPLETED,
|
||||
TaskState.FAILED,
|
||||
TaskState.CANCELED,
|
||||
]:
|
||||
if current_state == TaskState.COMPLETED:
|
||||
# Extract the response
|
||||
if (
|
||||
"status" in task_status["result"]
|
||||
and "message" in task_status["result"]["status"]
|
||||
and "parts"
|
||||
in task_status["result"]["status"]["message"]
|
||||
):
|
||||
|
||||
# Convert A2A parts to ADK
|
||||
response_parts = []
|
||||
for part in task_status["result"]["status"][
|
||||
"message"
|
||||
]["parts"]:
|
||||
if "text" in part and part["text"]:
|
||||
response_parts.append(
|
||||
Part(text=part["text"])
|
||||
)
|
||||
elif "data" in part:
|
||||
try:
|
||||
json_text = json.dumps(
|
||||
part["data"],
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
response_parts.append(
|
||||
Part(
|
||||
text=f"```json\n{json_text}\n```"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
response_parts.append(
|
||||
Part(text="[Unserializable data]")
|
||||
)
|
||||
|
||||
if response_parts:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent", parts=response_parts
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text="Empty response from agent."
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text="Task completed, but no response message."
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
elif current_state == TaskState.FAILED:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text="The task failed during processing."
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
else: # CANCELED
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[Part(text="The task was canceled.")],
|
||||
),
|
||||
)
|
||||
|
||||
# Store in the session state for future reference
|
||||
if ctx.session:
|
||||
try:
|
||||
ctx.session.state["a2a_task_result"] = task_status[
|
||||
"result"
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
break # Exit the loop of checking
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking task: {str(e)}")
|
||||
|
||||
# If the timeout was exceeded, inform the user
|
||||
if time.time() - start_time > self.max_wait_time:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[Part(text=f"Error checking task: {str(e)}")],
|
||||
),
|
||||
)
|
||||
break
|
||||
|
||||
# Wait before the next check
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
# If the timeout was exceeded
|
||||
if time.time() - start_time >= self.timeout:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(
|
||||
text="The operation exceeded the timeout. Please try again later."
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
# Run sub-agents
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending task: {str(e)}")
|
||||
error_msg = f"Error sending request: {str(e)}"
|
||||
print(error_msg)
|
||||
print(f"Error type: {type(e).__name__}")
|
||||
print(f"Error details: {str(e)}")
|
||||
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[
|
||||
Part(text=f"Error interacting with A2A agent: {str(e)}")
|
||||
],
|
||||
),
|
||||
content=Content(role="agent", parts=[Part(text=error_msg)]),
|
||||
)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
# Handle any uncaught error
|
||||
logger.error(f"Error executing A2A agent: {str(e)}")
|
||||
print(f"Error executing A2A agent: {str(e)}")
|
||||
yield Event(
|
||||
author=self.name,
|
||||
content=Content(
|
||||
@ -522,7 +165,3 @@ class A2ACustomAgent(BaseAgent):
|
||||
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")],
|
||||
),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Ensure that the final event is always generated
|
||||
yield Event(author=self.name)
|
||||
|
@ -204,19 +204,23 @@ class AgentBuilder:
|
||||
raise ValueError("agent_card_url is required for a2a agents")
|
||||
|
||||
try:
|
||||
sub_agents = []
|
||||
if root_agent.config.get("sub_agents"):
|
||||
sub_agents_with_stacks = await self._get_sub_agents(
|
||||
root_agent.config.get("sub_agents")
|
||||
)
|
||||
sub_agents = [agent for agent, _ in sub_agents_with_stacks]
|
||||
|
||||
config = root_agent.config or {}
|
||||
poll_interval = config.get("poll_interval", 1.0)
|
||||
max_wait_time = config.get("max_wait_time", 60)
|
||||
timeout = config.get("timeout", 300)
|
||||
|
||||
a2a_agent = A2ACustomAgent(
|
||||
name=root_agent.name,
|
||||
agent_card_url=root_agent.agent_card_url,
|
||||
poll_interval=poll_interval,
|
||||
max_wait_time=max_wait_time,
|
||||
timeout=timeout,
|
||||
description=root_agent.description
|
||||
or f"A2A Agent for {root_agent.name}",
|
||||
sub_agents=sub_agents,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
@ -72,7 +72,7 @@ async def run_agent(
|
||||
content = Content(role="user", parts=[Part(text=message)])
|
||||
logger.info("Starting agent execution")
|
||||
|
||||
final_response_text = None
|
||||
final_response_text = "No final response captured."
|
||||
try:
|
||||
response_queue = asyncio.Queue()
|
||||
execution_completed = asyncio.Event()
|
||||
|
Loading…
Reference in New Issue
Block a user