diff --git a/src/api/agent_routes.py b/src/api/agent_routes.py index 0bdd6d9e..4e222be3 100644 --- a/src/api/agent_routes.py +++ b/src/api/agent_routes.py @@ -15,7 +15,6 @@ from src.services import ( agent_service, mcp_server_service, ) -from src.models.models import Agent as AgentModel import logging logger = logging.getLogger(__name__) diff --git a/src/schemas/a2a/__init__.py b/src/schemas/a2a/__init__.py index 68293cc2..35c904ff 100644 --- a/src/schemas/a2a/__init__.py +++ b/src/schemas/a2a/__init__.py @@ -4,6 +4,137 @@ A2A (Agent-to-Agent) schema package. This package contains Pydantic schema definitions for the A2A protocol. """ -from src.schemas.a2a.types import * -from src.schemas.a2a.exceptions import * -from src.schemas.a2a.validators import * +from src.schemas.a2a.types import ( + TaskState, + TextPart, + FileContent, + FilePart, + DataPart, + Part, + Message, + TaskStatus, + Artifact, + Task, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + AuthenticationInfo, + PushNotificationConfig, + TaskIdParams, + TaskQueryParams, + TaskSendParams, + TaskPushNotificationConfig, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCError, + SendTaskRequest, + SendTaskResponse, + SendTaskStreamingRequest, + SendTaskStreamingResponse, + GetTaskRequest, + GetTaskResponse, + CancelTaskRequest, + CancelTaskResponse, + SetTaskPushNotificationRequest, + SetTaskPushNotificationResponse, + GetTaskPushNotificationRequest, + GetTaskPushNotificationResponse, + TaskResubscriptionRequest, + A2ARequest, + AgentProvider, + AgentCapabilities, + AgentAuthentication, + AgentSkill, + AgentCard, +) + +from src.schemas.a2a.exceptions import ( + JSONParseError, + InvalidRequestError, + MethodNotFoundError, + InvalidParamsError, + InternalError, + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + UnsupportedOperationError, + ContentTypeNotSupportedError, + A2AClientError, + A2AClientHTTPError, + A2AClientJSONError, + MissingAPIKeyError, +) + +from src.schemas.a2a.validators import ( + validate_base64, + validate_file_content, + validate_message_parts, + text_to_parts, + parts_to_text, +) + +__all__ = [ + # From types + "TaskState", + "TextPart", + "FileContent", + "FilePart", + "DataPart", + "Part", + "Message", + "TaskStatus", + "Artifact", + "Task", + "TaskStatusUpdateEvent", + "TaskArtifactUpdateEvent", + "AuthenticationInfo", + "PushNotificationConfig", + "TaskIdParams", + "TaskQueryParams", + "TaskSendParams", + "TaskPushNotificationConfig", + "JSONRPCMessage", + "JSONRPCRequest", + "JSONRPCResponse", + "JSONRPCError", + "SendTaskRequest", + "SendTaskResponse", + "SendTaskStreamingRequest", + "SendTaskStreamingResponse", + "GetTaskRequest", + "GetTaskResponse", + "CancelTaskRequest", + "CancelTaskResponse", + "SetTaskPushNotificationRequest", + "SetTaskPushNotificationResponse", + "GetTaskPushNotificationRequest", + "GetTaskPushNotificationResponse", + "TaskResubscriptionRequest", + "A2ARequest", + "AgentProvider", + "AgentCapabilities", + "AgentAuthentication", + "AgentSkill", + "AgentCard", + # From exceptions + "JSONParseError", + "InvalidRequestError", + "MethodNotFoundError", + "InvalidParamsError", + "InternalError", + "TaskNotFoundError", + "TaskNotCancelableError", + "PushNotificationNotSupportedError", + "UnsupportedOperationError", + "ContentTypeNotSupportedError", + "A2AClientError", + "A2AClientHTTPError", + "A2AClientJSONError", + "MissingAPIKeyError", + # From validators + "validate_base64", + "validate_file_content", + "validate_message_parts", + "text_to_parts", + "parts_to_text", +] diff --git a/src/services/a2a_agent.py b/src/services/a2a_agent.py new file mode 100644 index 00000000..b57eb573 --- /dev/null +++ b/src/services/a2a_agent.py @@ -0,0 +1,329 @@ +from google.adk.agents import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events import Event +from google.genai.types import Content, Part +from typing import AsyncGenerator +import json +import asyncio +import time + +from src.schemas.a2a.types import ( + GetTaskRequest, + SendTaskRequest, + Message, + TextPart, + TaskState, +) + +import httpx + +from uuid import uuid4 + + +class A2ACustomAgent(BaseAgent): + """ + Custom agent that implements the A2A protocol directly. + + This agent implements the interaction with an external A2A service. + """ + + # Field declarations for Pydantic + agent_card_url: str + poll_interval: float + max_wait_time: int + timeout: int + + def __init__( + self, + name: str, + agent_card_url: str, + poll_interval: float = 1.0, + max_wait_time: int = 60, + timeout: int = 300, + **kwargs, + ): + """ + Initialize the A2A agent. + + Args: + name: Agent name + agent_card_url: A2A agent card URL + poll_interval: Status check interval (seconds) + max_wait_time: Maximum wait time for a task (seconds) + timeout: Maximum execution time (seconds) + """ + # Initialize base class + super().__init__( + name=name, + agent_card_url=agent_card_url, + poll_interval=poll_interval, + max_wait_time=max_wait_time, + timeout=timeout, + **kwargs, + ) + + print(f"A2A agent initialized for URL: {agent_card_url}") + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """ + Implementation of the A2A protocol according to the Google ADK documentation. + + This method follows the pattern of implementing custom agents, + sending the user's message to the A2A service and monitoring the response. + """ + # 1. Yield the initial event + yield Event(author=self.name) + + try: + # Prepare the base URL for the A2A + url = self.agent_card_url + + # Ensure that there is no /.well-known/agent.json in the url + if "/.well-known/agent.json" in url: + url = url.split("/.well-known/agent.json")[0] + + # 2. Extract the user's message from the context + user_message = None + + # Search for the user's message in the session events + if ctx.session and hasattr(ctx.session, "events") and ctx.session.events: + for event in reversed(ctx.session.events): + if event.author == "user" and event.content and event.content.parts: + user_message = event.content.parts[0].text + print("Message found in session events") + break + + # Check in the session state if the message was not found in the events + if not user_message and ctx.session and ctx.session.state: + if "user_message" in ctx.session.state: + user_message = ctx.session.state["user_message"] + elif "message" in ctx.session.state: + user_message = ctx.session.state["message"] + + # 3. Create and send the task to the A2A agent + print(f"Sending task to A2A agent: {user_message[:100]}...") + + # Use the session ID as a stable identifier + session_id = ( + str(ctx.session.id) + if ctx.session and hasattr(ctx.session, "id") + else str(uuid4()) + ) + task_id = str(uuid4()) + + try: + + formatted_message: Message = Message( + role="user", + parts=[TextPart(type="text", text=user_message)], + ) + + request: SendTaskRequest = SendTaskRequest( + params={ + "message": formatted_message, + "sessionId": session_id, + "id": task_id, + } + ) + + print(f"Request send task: {request.model_dump()}") + + # REQUEST POST to url when jsonrpc is 2.0 + task_result = await httpx.AsyncClient().post( + url, json=request.model_dump(), timeout=30 + ) + + print(f"Task response: {task_result.json()}") + print(f"Task sent successfully, ID: {task_id}") + + yield Event( + author=self.name, + content=Content( + role="agent", parts=[Part(text="Processing request...")] + ), + ) + except Exception as e: + error_msg = f"Error sending request: {str(e)}" + print(error_msg) + print(f"Error type: {type(e).__name__}") + print(f"Error details: {str(e)}") + + yield Event( + author=self.name, + content=Content(role="agent", parts=[Part(text=str(e))]), + ) + yield Event(author=self.name) # Final event + return + + start_time = time.time() + + while time.time() - start_time < self.timeout: + try: + # Check current status + request: GetTaskRequest = GetTaskRequest(params={"id": task_id}) + + task_status_response = await httpx.AsyncClient().post( + url, json=request.model_dump(), timeout=30 + ) + + print(f"Response get task: {task_status_response.json()}") + + task_status = task_status_response.json() + + if "result" not in task_status or not task_status["result"]: + await asyncio.sleep(self.poll_interval) + continue + + current_state = None + if ( + "status" in task_status["result"] + and task_status["result"]["status"] + ): + if "state" in task_status["result"]["status"]: + current_state = task_status["result"]["status"]["state"] + + print(f"Task status {task_id}: {current_state}") + + # Check if the task was completed + if current_state in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + ]: + if current_state == TaskState.COMPLETED: + # Extract the response + if ( + "status" in task_status["result"] + and "message" in task_status["result"]["status"] + and "parts" + in task_status["result"]["status"]["message"] + ): + + # Convert A2A parts to ADK + response_parts = [] + for part in task_status["result"]["status"]["message"][ + "parts" + ]: + if "text" in part and part["text"]: + response_parts.append(Part(text=part["text"])) + elif "data" in part: + try: + json_text = json.dumps( + part["data"], + ensure_ascii=False, + indent=2, + ) + response_parts.append( + Part(text=f"```json\n{json_text}\n```") + ) + except Exception: + response_parts.append( + Part(text="[Unserializable data]") + ) + + if response_parts: + yield Event( + author=self.name, + content=Content( + role="agent", parts=response_parts + ), + ) + else: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part(text="Empty response from agent.") + ], + ), + ) + else: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text="Task completed, but no response message." + ) + ], + ), + ) + elif current_state == TaskState.FAILED: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part(text="The task failed during processing.") + ], + ), + ) + else: # CANCELED + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[Part(text="The task was canceled.")], + ), + ) + + # Store in the session state for future reference + if ctx.session: + try: + ctx.session.state["a2a_task_result"] = task_status[ + "result" + ] + except Exception: + pass + + break # Exit the loop of checking + + except Exception as e: + print(f"Error checking task: {str(e)}") + + # If the timeout was exceeded, inform the user + if time.time() - start_time > self.max_wait_time: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[Part(text=f"Error checking task: {str(e)}")], + ), + ) + break + + # Wait before the next check + await asyncio.sleep(self.poll_interval) + + # If the timeout was exceeded + if time.time() - start_time >= self.timeout: + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[ + Part( + text="The operation exceeded the timeout. Please try again later." + ) + ], + ), + ) + + except Exception as e: + # Handle any uncaught error + print(f"Error executing A2A agent: {str(e)}") + yield Event( + author=self.name, + content=Content( + role="agent", + parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")], + ), + ) + + finally: + # Ensure that the final event is always generated + yield Event(author=self.name) diff --git a/src/services/a2a_server_service.py b/src/services/a2a_server_service.py index 4e71d176..bddf1ac8 100644 --- a/src/services/a2a_server_service.py +++ b/src/services/a2a_server_service.py @@ -5,10 +5,8 @@ This module implements a JSON-RPC compatible server for the A2A protocol, that manages agent tasks, streaming events and push notifications. """ -import asyncio import json import logging -import uuid from datetime import datetime from typing import ( Any, @@ -16,18 +14,14 @@ from typing import ( List, Optional, AsyncGenerator, - Callable, Union, AsyncIterable, ) -import httpx from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse, Response -from pydantic import BaseModel, Field from sqlalchemy.orm import Session from src.schemas.a2a.types import A2ARequest -from src.services.agent_runner import run_agent from src.services.a2a_integration_service import ( AgentRunnerAdapter, StreamingServiceAdapter, @@ -42,9 +36,7 @@ from src.schemas.a2a.types import ( SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, TaskResubscriptionRequest, - TaskSendParams, ) -from src.utils.a2a_utils import are_modalities_compatible logger = logging.getLogger(__name__) @@ -579,7 +571,9 @@ class A2AServer: "error": { "code": -32601, "message": "Method not found", - "data": {"detail": f"Method not supported"}, + "data": { + "detail": f"Method not supported: {method}" + }, }, }, ) diff --git a/src/services/a2a_task_manager_service.py b/src/services/a2a_task_manager_service.py index fc88fa43..6d5b926a 100644 --- a/src/services/a2a_task_manager_service.py +++ b/src/services/a2a_task_manager_service.py @@ -6,12 +6,9 @@ including execution, streaming, push notifications, status queries, and cancella """ import asyncio -import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional, Union, AsyncIterable - -from sqlalchemy.orm import Session +from typing import Any, Dict, Union, AsyncIterable from src.schemas.a2a.exceptions import ( TaskNotFoundError, @@ -23,8 +20,6 @@ from src.schemas.a2a.exceptions import ( from src.schemas.a2a.types import ( JSONRPCResponse, - TaskIdParams, - TaskQueryParams, GetTaskRequest, SendTaskRequest, CancelTaskRequest, @@ -55,9 +50,6 @@ from src.services.redis_cache_service import RedisCacheService from src.utils.a2a_utils import ( are_modalities_compatible, new_incompatible_types_error, - new_not_implemented_error, - create_task_id, - format_error_response, ) logger = logging.getLogger(__name__) diff --git a/src/services/agent_builder.py b/src/services/agent_builder.py index 7f482fdf..c853b693 100644 --- a/src/services/agent_builder.py +++ b/src/services/agent_builder.py @@ -1,12 +1,13 @@ from typing import List, Optional, Tuple from google.adk.agents.llm_agent import LlmAgent -from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent +from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent, BaseAgent from google.adk.models.lite_llm import LiteLlm from src.utils.logger import setup_logger from src.core.exceptions import AgentNotFoundError from src.services.agent_service import get_agent from src.services.custom_tools import CustomToolBuilder from src.services.mcp_service import MCPService +from src.services.a2a_agent import A2ACustomAgent from sqlalchemy.orm import Session from contextlib import AsyncExitStack from google.adk.tools import load_memory @@ -87,12 +88,19 @@ class AgentBuilder: if agent is None: raise AgentNotFoundError(f"Agent with ID {sub_agent_id} not found") - if agent.type != "llm": - raise ValueError( - f"Agent {agent.name} (ID: {agent.id}) is not an LLM agent" - ) + if agent.type == "llm": + sub_agent, exit_stack = await self._create_llm_agent(agent) + elif agent.type == "a2a": + sub_agent, exit_stack = await self.build_a2a_agent(agent) + elif agent.type == "sequential": + sub_agent, exit_stack = await self.build_composite_agent(agent) + elif agent.type == "parallel": + sub_agent, exit_stack = await self.build_composite_agent(agent) + elif agent.type == "loop": + sub_agent, exit_stack = await self.build_composite_agent(agent) + else: + raise ValueError(f"Invalid agent type: {agent.type}") - sub_agent, exit_stack = await self._create_llm_agent(agent) sub_agents.append((sub_agent, exit_stack)) return sub_agents @@ -116,6 +124,41 @@ class AgentBuilder: return root_llm_agent, exit_stack + async def build_a2a_agent( + self, root_agent + ) -> Tuple[BaseAgent, Optional[AsyncExitStack]]: + """Build an A2A agent with its sub-agents.""" + logger.info(f"Creating A2A agent from {root_agent.agent_card_url}") + + if not root_agent.agent_card_url: + raise ValueError("agent_card_url is required for a2a agents") + + try: + config = root_agent.config or {} + poll_interval = config.get("poll_interval", 1.0) + max_wait_time = config.get("max_wait_time", 60) + timeout = config.get("timeout", 300) + + a2a_agent = A2ACustomAgent( + name=root_agent.name, + agent_card_url=root_agent.agent_card_url, + poll_interval=poll_interval, + max_wait_time=max_wait_time, + timeout=timeout, + description=root_agent.description + or f"A2A Agent for {root_agent.name}", + ) + + logger.info( + f"A2A agent created successfully: {root_agent.name} ({root_agent.agent_card_url})" + ) + + return a2a_agent, None + + except Exception as e: + logger.error(f"Error building A2A agent: {str(e)}") + raise ValueError(f"Error building A2A agent: {str(e)}") + async def build_composite_agent( self, root_agent ) -> Tuple[SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack]]: @@ -161,13 +204,14 @@ class AgentBuilder: else: raise ValueError(f"Invalid agent type: {root_agent.type}") - async def build_agent( - self, root_agent - ) -> Tuple[ - LlmAgent | SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack] + async def build_agent(self, root_agent) -> Tuple[ + LlmAgent | SequentialAgent | ParallelAgent | LoopAgent | A2ACustomAgent, + Optional[AsyncExitStack], ]: """Build the appropriate agent based on the type of the root agent.""" if root_agent.type == "llm": return await self.build_llm_agent(root_agent) + elif root_agent.type == "a2a": + return await self.build_a2a_agent(root_agent) else: return await self.build_composite_agent(root_agent) diff --git a/src/services/push_notification_auth_service.py b/src/services/push_notification_auth_service.py index ca4174ec..228a8215 100644 --- a/src/services/push_notification_auth_service.py +++ b/src/services/push_notification_auth_service.py @@ -17,7 +17,7 @@ import jwt from jwt import PyJWK, PyJWKClient from fastapi import Request from starlette.responses import JSONResponse -from typing import Dict, Any, Optional +from typing import Dict, Any logger = logging.getLogger(__name__) AUTH_HEADER_PREFIX = "Bearer " diff --git a/src/services/redis_cache_service.py b/src/services/redis_cache_service.py index 0c4bccda..72ed0ce2 100644 --- a/src/services/redis_cache_service.py +++ b/src/services/redis_cache_service.py @@ -7,11 +7,10 @@ push notification configurations, and other A2A-related data. import json import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import asyncio import redis.asyncio as aioredis -from redis.exceptions import RedisError -from src.config.redis import get_redis_config, get_a2a_config +from src.config.redis import get_redis_config import threading import time