feat(a2a): implement A2A custom agent and enhance agent builder for A2A type agents

This commit is contained in:
Davidson Gomes 2025-04-30 20:16:29 -03:00
parent c14d23333c
commit 8de86c22ee
8 changed files with 524 additions and 36 deletions

View File

@ -15,7 +15,6 @@ from src.services import (
agent_service,
mcp_server_service,
)
from src.models.models import Agent as AgentModel
import logging
logger = logging.getLogger(__name__)

View File

@ -4,6 +4,137 @@ A2A (Agent-to-Agent) schema package.
This package contains Pydantic schema definitions for the A2A protocol.
"""
from src.schemas.a2a.types import *
from src.schemas.a2a.exceptions import *
from src.schemas.a2a.validators import *
from src.schemas.a2a.types import (
TaskState,
TextPart,
FileContent,
FilePart,
DataPart,
Part,
Message,
TaskStatus,
Artifact,
Task,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
AuthenticationInfo,
PushNotificationConfig,
TaskIdParams,
TaskQueryParams,
TaskSendParams,
TaskPushNotificationConfig,
JSONRPCMessage,
JSONRPCRequest,
JSONRPCResponse,
JSONRPCError,
SendTaskRequest,
SendTaskResponse,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
GetTaskRequest,
GetTaskResponse,
CancelTaskRequest,
CancelTaskResponse,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
TaskResubscriptionRequest,
A2ARequest,
AgentProvider,
AgentCapabilities,
AgentAuthentication,
AgentSkill,
AgentCard,
)
from src.schemas.a2a.exceptions import (
JSONParseError,
InvalidRequestError,
MethodNotFoundError,
InvalidParamsError,
InternalError,
TaskNotFoundError,
TaskNotCancelableError,
PushNotificationNotSupportedError,
UnsupportedOperationError,
ContentTypeNotSupportedError,
A2AClientError,
A2AClientHTTPError,
A2AClientJSONError,
MissingAPIKeyError,
)
from src.schemas.a2a.validators import (
validate_base64,
validate_file_content,
validate_message_parts,
text_to_parts,
parts_to_text,
)
__all__ = [
# From types
"TaskState",
"TextPart",
"FileContent",
"FilePart",
"DataPart",
"Part",
"Message",
"TaskStatus",
"Artifact",
"Task",
"TaskStatusUpdateEvent",
"TaskArtifactUpdateEvent",
"AuthenticationInfo",
"PushNotificationConfig",
"TaskIdParams",
"TaskQueryParams",
"TaskSendParams",
"TaskPushNotificationConfig",
"JSONRPCMessage",
"JSONRPCRequest",
"JSONRPCResponse",
"JSONRPCError",
"SendTaskRequest",
"SendTaskResponse",
"SendTaskStreamingRequest",
"SendTaskStreamingResponse",
"GetTaskRequest",
"GetTaskResponse",
"CancelTaskRequest",
"CancelTaskResponse",
"SetTaskPushNotificationRequest",
"SetTaskPushNotificationResponse",
"GetTaskPushNotificationRequest",
"GetTaskPushNotificationResponse",
"TaskResubscriptionRequest",
"A2ARequest",
"AgentProvider",
"AgentCapabilities",
"AgentAuthentication",
"AgentSkill",
"AgentCard",
# From exceptions
"JSONParseError",
"InvalidRequestError",
"MethodNotFoundError",
"InvalidParamsError",
"InternalError",
"TaskNotFoundError",
"TaskNotCancelableError",
"PushNotificationNotSupportedError",
"UnsupportedOperationError",
"ContentTypeNotSupportedError",
"A2AClientError",
"A2AClientHTTPError",
"A2AClientJSONError",
"MissingAPIKeyError",
# From validators
"validate_base64",
"validate_file_content",
"validate_message_parts",
"text_to_parts",
"parts_to_text",
]

329
src/services/a2a_agent.py Normal file
View File

@ -0,0 +1,329 @@
from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.genai.types import Content, Part
from typing import AsyncGenerator
import json
import asyncio
import time
from src.schemas.a2a.types import (
GetTaskRequest,
SendTaskRequest,
Message,
TextPart,
TaskState,
)
import httpx
from uuid import uuid4
class A2ACustomAgent(BaseAgent):
"""
Custom agent that implements the A2A protocol directly.
This agent implements the interaction with an external A2A service.
"""
# Field declarations for Pydantic
agent_card_url: str
poll_interval: float
max_wait_time: int
timeout: int
def __init__(
self,
name: str,
agent_card_url: str,
poll_interval: float = 1.0,
max_wait_time: int = 60,
timeout: int = 300,
**kwargs,
):
"""
Initialize the A2A agent.
Args:
name: Agent name
agent_card_url: A2A agent card URL
poll_interval: Status check interval (seconds)
max_wait_time: Maximum wait time for a task (seconds)
timeout: Maximum execution time (seconds)
"""
# Initialize base class
super().__init__(
name=name,
agent_card_url=agent_card_url,
poll_interval=poll_interval,
max_wait_time=max_wait_time,
timeout=timeout,
**kwargs,
)
print(f"A2A agent initialized for URL: {agent_card_url}")
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""
Implementation of the A2A protocol according to the Google ADK documentation.
This method follows the pattern of implementing custom agents,
sending the user's message to the A2A service and monitoring the response.
"""
# 1. Yield the initial event
yield Event(author=self.name)
try:
# Prepare the base URL for the A2A
url = self.agent_card_url
# Ensure that there is no /.well-known/agent.json in the url
if "/.well-known/agent.json" in url:
url = url.split("/.well-known/agent.json")[0]
# 2. Extract the user's message from the context
user_message = None
# Search for the user's message in the session events
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts:
user_message = event.content.parts[0].text
print("Message found in session events")
break
# Check in the session state if the message was not found in the events
if not user_message and ctx.session and ctx.session.state:
if "user_message" in ctx.session.state:
user_message = ctx.session.state["user_message"]
elif "message" in ctx.session.state:
user_message = ctx.session.state["message"]
# 3. Create and send the task to the A2A agent
print(f"Sending task to A2A agent: {user_message[:100]}...")
# Use the session ID as a stable identifier
session_id = (
str(ctx.session.id)
if ctx.session and hasattr(ctx.session, "id")
else str(uuid4())
)
task_id = str(uuid4())
try:
formatted_message: Message = Message(
role="user",
parts=[TextPart(type="text", text=user_message)],
)
request: SendTaskRequest = SendTaskRequest(
params={
"message": formatted_message,
"sessionId": session_id,
"id": task_id,
}
)
print(f"Request send task: {request.model_dump()}")
# REQUEST POST to url when jsonrpc is 2.0
task_result = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
)
print(f"Task response: {task_result.json()}")
print(f"Task sent successfully, ID: {task_id}")
yield Event(
author=self.name,
content=Content(
role="agent", parts=[Part(text="Processing request...")]
),
)
except Exception as e:
error_msg = f"Error sending request: {str(e)}"
print(error_msg)
print(f"Error type: {type(e).__name__}")
print(f"Error details: {str(e)}")
yield Event(
author=self.name,
content=Content(role="agent", parts=[Part(text=str(e))]),
)
yield Event(author=self.name) # Final event
return
start_time = time.time()
while time.time() - start_time < self.timeout:
try:
# Check current status
request: GetTaskRequest = GetTaskRequest(params={"id": task_id})
task_status_response = await httpx.AsyncClient().post(
url, json=request.model_dump(), timeout=30
)
print(f"Response get task: {task_status_response.json()}")
task_status = task_status_response.json()
if "result" not in task_status or not task_status["result"]:
await asyncio.sleep(self.poll_interval)
continue
current_state = None
if (
"status" in task_status["result"]
and task_status["result"]["status"]
):
if "state" in task_status["result"]["status"]:
current_state = task_status["result"]["status"]["state"]
print(f"Task status {task_id}: {current_state}")
# Check if the task was completed
if current_state in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]:
if current_state == TaskState.COMPLETED:
# Extract the response
if (
"status" in task_status["result"]
and "message" in task_status["result"]["status"]
and "parts"
in task_status["result"]["status"]["message"]
):
# Convert A2A parts to ADK
response_parts = []
for part in task_status["result"]["status"]["message"][
"parts"
]:
if "text" in part and part["text"]:
response_parts.append(Part(text=part["text"]))
elif "data" in part:
try:
json_text = json.dumps(
part["data"],
ensure_ascii=False,
indent=2,
)
response_parts.append(
Part(text=f"```json\n{json_text}\n```")
)
except Exception:
response_parts.append(
Part(text="[Unserializable data]")
)
if response_parts:
yield Event(
author=self.name,
content=Content(
role="agent", parts=response_parts
),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="Empty response from agent.")
],
),
)
else:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="Task completed, but no response message."
)
],
),
)
elif current_state == TaskState.FAILED:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(text="The task failed during processing.")
],
),
)
else: # CANCELED
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text="The task was canceled.")],
),
)
# Store in the session state for future reference
if ctx.session:
try:
ctx.session.state["a2a_task_result"] = task_status[
"result"
]
except Exception:
pass
break # Exit the loop of checking
except Exception as e:
print(f"Error checking task: {str(e)}")
# If the timeout was exceeded, inform the user
if time.time() - start_time > self.max_wait_time:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error checking task: {str(e)}")],
),
)
break
# Wait before the next check
await asyncio.sleep(self.poll_interval)
# If the timeout was exceeded
if time.time() - start_time >= self.timeout:
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[
Part(
text="The operation exceeded the timeout. Please try again later."
)
],
),
)
except Exception as e:
# Handle any uncaught error
print(f"Error executing A2A agent: {str(e)}")
yield Event(
author=self.name,
content=Content(
role="agent",
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")],
),
)
finally:
# Ensure that the final event is always generated
yield Event(author=self.name)

View File

@ -5,10 +5,8 @@ This module implements a JSON-RPC compatible server for the A2A protocol,
that manages agent tasks, streaming events and push notifications.
"""
import asyncio
import json
import logging
import uuid
from datetime import datetime
from typing import (
Any,
@ -16,18 +14,14 @@ from typing import (
List,
Optional,
AsyncGenerator,
Callable,
Union,
AsyncIterable,
)
import httpx
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse, Response
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from src.schemas.a2a.types import A2ARequest
from src.services.agent_runner import run_agent
from src.services.a2a_integration_service import (
AgentRunnerAdapter,
StreamingServiceAdapter,
@ -42,9 +36,7 @@ from src.schemas.a2a.types import (
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
TaskSendParams,
)
from src.utils.a2a_utils import are_modalities_compatible
logger = logging.getLogger(__name__)
@ -579,7 +571,9 @@ class A2AServer:
"error": {
"code": -32601,
"message": "Method not found",
"data": {"detail": f"Method not supported"},
"data": {
"detail": f"Method not supported: {method}"
},
},
},
)

View File

@ -6,12 +6,9 @@ including execution, streaming, push notifications, status queries, and cancella
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, AsyncIterable
from sqlalchemy.orm import Session
from typing import Any, Dict, Union, AsyncIterable
from src.schemas.a2a.exceptions import (
TaskNotFoundError,
@ -23,8 +20,6 @@ from src.schemas.a2a.exceptions import (
from src.schemas.a2a.types import (
JSONRPCResponse,
TaskIdParams,
TaskQueryParams,
GetTaskRequest,
SendTaskRequest,
CancelTaskRequest,
@ -55,9 +50,6 @@ from src.services.redis_cache_service import RedisCacheService
from src.utils.a2a_utils import (
are_modalities_compatible,
new_incompatible_types_error,
new_not_implemented_error,
create_task_id,
format_error_response,
)
logger = logging.getLogger(__name__)

View File

@ -1,12 +1,13 @@
from typing import List, Optional, Tuple
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent, BaseAgent
from google.adk.models.lite_llm import LiteLlm
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError
from src.services.agent_service import get_agent
from src.services.custom_tools import CustomToolBuilder
from src.services.mcp_service import MCPService
from src.services.a2a_agent import A2ACustomAgent
from sqlalchemy.orm import Session
from contextlib import AsyncExitStack
from google.adk.tools import load_memory
@ -87,12 +88,19 @@ class AgentBuilder:
if agent is None:
raise AgentNotFoundError(f"Agent with ID {sub_agent_id} not found")
if agent.type != "llm":
raise ValueError(
f"Agent {agent.name} (ID: {agent.id}) is not an LLM agent"
)
if agent.type == "llm":
sub_agent, exit_stack = await self._create_llm_agent(agent)
elif agent.type == "a2a":
sub_agent, exit_stack = await self.build_a2a_agent(agent)
elif agent.type == "sequential":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "parallel":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "loop":
sub_agent, exit_stack = await self.build_composite_agent(agent)
else:
raise ValueError(f"Invalid agent type: {agent.type}")
sub_agent, exit_stack = await self._create_llm_agent(agent)
sub_agents.append((sub_agent, exit_stack))
return sub_agents
@ -116,6 +124,41 @@ class AgentBuilder:
return root_llm_agent, exit_stack
async def build_a2a_agent(
self, root_agent
) -> Tuple[BaseAgent, Optional[AsyncExitStack]]:
"""Build an A2A agent with its sub-agents."""
logger.info(f"Creating A2A agent from {root_agent.agent_card_url}")
if not root_agent.agent_card_url:
raise ValueError("agent_card_url is required for a2a agents")
try:
config = root_agent.config or {}
poll_interval = config.get("poll_interval", 1.0)
max_wait_time = config.get("max_wait_time", 60)
timeout = config.get("timeout", 300)
a2a_agent = A2ACustomAgent(
name=root_agent.name,
agent_card_url=root_agent.agent_card_url,
poll_interval=poll_interval,
max_wait_time=max_wait_time,
timeout=timeout,
description=root_agent.description
or f"A2A Agent for {root_agent.name}",
)
logger.info(
f"A2A agent created successfully: {root_agent.name} ({root_agent.agent_card_url})"
)
return a2a_agent, None
except Exception as e:
logger.error(f"Error building A2A agent: {str(e)}")
raise ValueError(f"Error building A2A agent: {str(e)}")
async def build_composite_agent(
self, root_agent
) -> Tuple[SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack]]:
@ -161,13 +204,14 @@ class AgentBuilder:
else:
raise ValueError(f"Invalid agent type: {root_agent.type}")
async def build_agent(
self, root_agent
) -> Tuple[
LlmAgent | SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack]
async def build_agent(self, root_agent) -> Tuple[
LlmAgent | SequentialAgent | ParallelAgent | LoopAgent | A2ACustomAgent,
Optional[AsyncExitStack],
]:
"""Build the appropriate agent based on the type of the root agent."""
if root_agent.type == "llm":
return await self.build_llm_agent(root_agent)
elif root_agent.type == "a2a":
return await self.build_a2a_agent(root_agent)
else:
return await self.build_composite_agent(root_agent)

View File

@ -17,7 +17,7 @@ import jwt
from jwt import PyJWK, PyJWKClient
from fastapi import Request
from starlette.responses import JSONResponse
from typing import Dict, Any, Optional
from typing import Dict, Any
logger = logging.getLogger(__name__)
AUTH_HEADER_PREFIX = "Bearer "

View File

@ -7,11 +7,10 @@ push notification configurations, and other A2A-related data.
import json
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
import asyncio
import redis.asyncio as aioredis
from redis.exceptions import RedisError
from src.config.redis import get_redis_config, get_a2a_config
from src.config.redis import get_redis_config
import threading
import time