feat(a2a): implement A2A custom agent and enhance agent builder for A2A type agents
This commit is contained in:
parent
c14d23333c
commit
8de86c22ee
@ -15,7 +15,6 @@ from src.services import (
|
|||||||
agent_service,
|
agent_service,
|
||||||
mcp_server_service,
|
mcp_server_service,
|
||||||
)
|
)
|
||||||
from src.models.models import Agent as AgentModel
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -4,6 +4,137 @@ A2A (Agent-to-Agent) schema package.
|
|||||||
This package contains Pydantic schema definitions for the A2A protocol.
|
This package contains Pydantic schema definitions for the A2A protocol.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.schemas.a2a.types import *
|
from src.schemas.a2a.types import (
|
||||||
from src.schemas.a2a.exceptions import *
|
TaskState,
|
||||||
from src.schemas.a2a.validators import *
|
TextPart,
|
||||||
|
FileContent,
|
||||||
|
FilePart,
|
||||||
|
DataPart,
|
||||||
|
Part,
|
||||||
|
Message,
|
||||||
|
TaskStatus,
|
||||||
|
Artifact,
|
||||||
|
Task,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
AuthenticationInfo,
|
||||||
|
PushNotificationConfig,
|
||||||
|
TaskIdParams,
|
||||||
|
TaskQueryParams,
|
||||||
|
TaskSendParams,
|
||||||
|
TaskPushNotificationConfig,
|
||||||
|
JSONRPCMessage,
|
||||||
|
JSONRPCRequest,
|
||||||
|
JSONRPCResponse,
|
||||||
|
JSONRPCError,
|
||||||
|
SendTaskRequest,
|
||||||
|
SendTaskResponse,
|
||||||
|
SendTaskStreamingRequest,
|
||||||
|
SendTaskStreamingResponse,
|
||||||
|
GetTaskRequest,
|
||||||
|
GetTaskResponse,
|
||||||
|
CancelTaskRequest,
|
||||||
|
CancelTaskResponse,
|
||||||
|
SetTaskPushNotificationRequest,
|
||||||
|
SetTaskPushNotificationResponse,
|
||||||
|
GetTaskPushNotificationRequest,
|
||||||
|
GetTaskPushNotificationResponse,
|
||||||
|
TaskResubscriptionRequest,
|
||||||
|
A2ARequest,
|
||||||
|
AgentProvider,
|
||||||
|
AgentCapabilities,
|
||||||
|
AgentAuthentication,
|
||||||
|
AgentSkill,
|
||||||
|
AgentCard,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.schemas.a2a.exceptions import (
|
||||||
|
JSONParseError,
|
||||||
|
InvalidRequestError,
|
||||||
|
MethodNotFoundError,
|
||||||
|
InvalidParamsError,
|
||||||
|
InternalError,
|
||||||
|
TaskNotFoundError,
|
||||||
|
TaskNotCancelableError,
|
||||||
|
PushNotificationNotSupportedError,
|
||||||
|
UnsupportedOperationError,
|
||||||
|
ContentTypeNotSupportedError,
|
||||||
|
A2AClientError,
|
||||||
|
A2AClientHTTPError,
|
||||||
|
A2AClientJSONError,
|
||||||
|
MissingAPIKeyError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.schemas.a2a.validators import (
|
||||||
|
validate_base64,
|
||||||
|
validate_file_content,
|
||||||
|
validate_message_parts,
|
||||||
|
text_to_parts,
|
||||||
|
parts_to_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# From types
|
||||||
|
"TaskState",
|
||||||
|
"TextPart",
|
||||||
|
"FileContent",
|
||||||
|
"FilePart",
|
||||||
|
"DataPart",
|
||||||
|
"Part",
|
||||||
|
"Message",
|
||||||
|
"TaskStatus",
|
||||||
|
"Artifact",
|
||||||
|
"Task",
|
||||||
|
"TaskStatusUpdateEvent",
|
||||||
|
"TaskArtifactUpdateEvent",
|
||||||
|
"AuthenticationInfo",
|
||||||
|
"PushNotificationConfig",
|
||||||
|
"TaskIdParams",
|
||||||
|
"TaskQueryParams",
|
||||||
|
"TaskSendParams",
|
||||||
|
"TaskPushNotificationConfig",
|
||||||
|
"JSONRPCMessage",
|
||||||
|
"JSONRPCRequest",
|
||||||
|
"JSONRPCResponse",
|
||||||
|
"JSONRPCError",
|
||||||
|
"SendTaskRequest",
|
||||||
|
"SendTaskResponse",
|
||||||
|
"SendTaskStreamingRequest",
|
||||||
|
"SendTaskStreamingResponse",
|
||||||
|
"GetTaskRequest",
|
||||||
|
"GetTaskResponse",
|
||||||
|
"CancelTaskRequest",
|
||||||
|
"CancelTaskResponse",
|
||||||
|
"SetTaskPushNotificationRequest",
|
||||||
|
"SetTaskPushNotificationResponse",
|
||||||
|
"GetTaskPushNotificationRequest",
|
||||||
|
"GetTaskPushNotificationResponse",
|
||||||
|
"TaskResubscriptionRequest",
|
||||||
|
"A2ARequest",
|
||||||
|
"AgentProvider",
|
||||||
|
"AgentCapabilities",
|
||||||
|
"AgentAuthentication",
|
||||||
|
"AgentSkill",
|
||||||
|
"AgentCard",
|
||||||
|
# From exceptions
|
||||||
|
"JSONParseError",
|
||||||
|
"InvalidRequestError",
|
||||||
|
"MethodNotFoundError",
|
||||||
|
"InvalidParamsError",
|
||||||
|
"InternalError",
|
||||||
|
"TaskNotFoundError",
|
||||||
|
"TaskNotCancelableError",
|
||||||
|
"PushNotificationNotSupportedError",
|
||||||
|
"UnsupportedOperationError",
|
||||||
|
"ContentTypeNotSupportedError",
|
||||||
|
"A2AClientError",
|
||||||
|
"A2AClientHTTPError",
|
||||||
|
"A2AClientJSONError",
|
||||||
|
"MissingAPIKeyError",
|
||||||
|
# From validators
|
||||||
|
"validate_base64",
|
||||||
|
"validate_file_content",
|
||||||
|
"validate_message_parts",
|
||||||
|
"text_to_parts",
|
||||||
|
"parts_to_text",
|
||||||
|
]
|
||||||
|
329
src/services/a2a_agent.py
Normal file
329
src/services/a2a_agent.py
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
from google.adk.agents import BaseAgent
|
||||||
|
from google.adk.agents.invocation_context import InvocationContext
|
||||||
|
from google.adk.events import Event
|
||||||
|
from google.genai.types import Content, Part
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.schemas.a2a.types import (
|
||||||
|
GetTaskRequest,
|
||||||
|
SendTaskRequest,
|
||||||
|
Message,
|
||||||
|
TextPart,
|
||||||
|
TaskState,
|
||||||
|
)
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
|
class A2ACustomAgent(BaseAgent):
|
||||||
|
"""
|
||||||
|
Custom agent that implements the A2A protocol directly.
|
||||||
|
|
||||||
|
This agent implements the interaction with an external A2A service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Field declarations for Pydantic
|
||||||
|
agent_card_url: str
|
||||||
|
poll_interval: float
|
||||||
|
max_wait_time: int
|
||||||
|
timeout: int
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
agent_card_url: str,
|
||||||
|
poll_interval: float = 1.0,
|
||||||
|
max_wait_time: int = 60,
|
||||||
|
timeout: int = 300,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the A2A agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Agent name
|
||||||
|
agent_card_url: A2A agent card URL
|
||||||
|
poll_interval: Status check interval (seconds)
|
||||||
|
max_wait_time: Maximum wait time for a task (seconds)
|
||||||
|
timeout: Maximum execution time (seconds)
|
||||||
|
"""
|
||||||
|
# Initialize base class
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
agent_card_url=agent_card_url,
|
||||||
|
poll_interval=poll_interval,
|
||||||
|
max_wait_time=max_wait_time,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"A2A agent initialized for URL: {agent_card_url}")
|
||||||
|
|
||||||
|
async def _run_async_impl(
|
||||||
|
self, ctx: InvocationContext
|
||||||
|
) -> AsyncGenerator[Event, None]:
|
||||||
|
"""
|
||||||
|
Implementation of the A2A protocol according to the Google ADK documentation.
|
||||||
|
|
||||||
|
This method follows the pattern of implementing custom agents,
|
||||||
|
sending the user's message to the A2A service and monitoring the response.
|
||||||
|
"""
|
||||||
|
# 1. Yield the initial event
|
||||||
|
yield Event(author=self.name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare the base URL for the A2A
|
||||||
|
url = self.agent_card_url
|
||||||
|
|
||||||
|
# Ensure that there is no /.well-known/agent.json in the url
|
||||||
|
if "/.well-known/agent.json" in url:
|
||||||
|
url = url.split("/.well-known/agent.json")[0]
|
||||||
|
|
||||||
|
# 2. Extract the user's message from the context
|
||||||
|
user_message = None
|
||||||
|
|
||||||
|
# Search for the user's message in the session events
|
||||||
|
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
|
||||||
|
for event in reversed(ctx.session.events):
|
||||||
|
if event.author == "user" and event.content and event.content.parts:
|
||||||
|
user_message = event.content.parts[0].text
|
||||||
|
print("Message found in session events")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check in the session state if the message was not found in the events
|
||||||
|
if not user_message and ctx.session and ctx.session.state:
|
||||||
|
if "user_message" in ctx.session.state:
|
||||||
|
user_message = ctx.session.state["user_message"]
|
||||||
|
elif "message" in ctx.session.state:
|
||||||
|
user_message = ctx.session.state["message"]
|
||||||
|
|
||||||
|
# 3. Create and send the task to the A2A agent
|
||||||
|
print(f"Sending task to A2A agent: {user_message[:100]}...")
|
||||||
|
|
||||||
|
# Use the session ID as a stable identifier
|
||||||
|
session_id = (
|
||||||
|
str(ctx.session.id)
|
||||||
|
if ctx.session and hasattr(ctx.session, "id")
|
||||||
|
else str(uuid4())
|
||||||
|
)
|
||||||
|
task_id = str(uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
formatted_message: Message = Message(
|
||||||
|
role="user",
|
||||||
|
parts=[TextPart(type="text", text=user_message)],
|
||||||
|
)
|
||||||
|
|
||||||
|
request: SendTaskRequest = SendTaskRequest(
|
||||||
|
params={
|
||||||
|
"message": formatted_message,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"id": task_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Request send task: {request.model_dump()}")
|
||||||
|
|
||||||
|
# REQUEST POST to url when jsonrpc is 2.0
|
||||||
|
task_result = await httpx.AsyncClient().post(
|
||||||
|
url, json=request.model_dump(), timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Task response: {task_result.json()}")
|
||||||
|
print(f"Task sent successfully, ID: {task_id}")
|
||||||
|
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent", parts=[Part(text="Processing request...")]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error sending request: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
print(f"Error type: {type(e).__name__}")
|
||||||
|
print(f"Error details: {str(e)}")
|
||||||
|
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(role="agent", parts=[Part(text=str(e))]),
|
||||||
|
)
|
||||||
|
yield Event(author=self.name) # Final event
|
||||||
|
return
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while time.time() - start_time < self.timeout:
|
||||||
|
try:
|
||||||
|
# Check current status
|
||||||
|
request: GetTaskRequest = GetTaskRequest(params={"id": task_id})
|
||||||
|
|
||||||
|
task_status_response = await httpx.AsyncClient().post(
|
||||||
|
url, json=request.model_dump(), timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Response get task: {task_status_response.json()}")
|
||||||
|
|
||||||
|
task_status = task_status_response.json()
|
||||||
|
|
||||||
|
if "result" not in task_status or not task_status["result"]:
|
||||||
|
await asyncio.sleep(self.poll_interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_state = None
|
||||||
|
if (
|
||||||
|
"status" in task_status["result"]
|
||||||
|
and task_status["result"]["status"]
|
||||||
|
):
|
||||||
|
if "state" in task_status["result"]["status"]:
|
||||||
|
current_state = task_status["result"]["status"]["state"]
|
||||||
|
|
||||||
|
print(f"Task status {task_id}: {current_state}")
|
||||||
|
|
||||||
|
# Check if the task was completed
|
||||||
|
if current_state in [
|
||||||
|
TaskState.COMPLETED,
|
||||||
|
TaskState.FAILED,
|
||||||
|
TaskState.CANCELED,
|
||||||
|
]:
|
||||||
|
if current_state == TaskState.COMPLETED:
|
||||||
|
# Extract the response
|
||||||
|
if (
|
||||||
|
"status" in task_status["result"]
|
||||||
|
and "message" in task_status["result"]["status"]
|
||||||
|
and "parts"
|
||||||
|
in task_status["result"]["status"]["message"]
|
||||||
|
):
|
||||||
|
|
||||||
|
# Convert A2A parts to ADK
|
||||||
|
response_parts = []
|
||||||
|
for part in task_status["result"]["status"]["message"][
|
||||||
|
"parts"
|
||||||
|
]:
|
||||||
|
if "text" in part and part["text"]:
|
||||||
|
response_parts.append(Part(text=part["text"]))
|
||||||
|
elif "data" in part:
|
||||||
|
try:
|
||||||
|
json_text = json.dumps(
|
||||||
|
part["data"],
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
response_parts.append(
|
||||||
|
Part(text=f"```json\n{json_text}\n```")
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
response_parts.append(
|
||||||
|
Part(text="[Unserializable data]")
|
||||||
|
)
|
||||||
|
|
||||||
|
if response_parts:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent", parts=response_parts
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[
|
||||||
|
Part(text="Empty response from agent.")
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[
|
||||||
|
Part(
|
||||||
|
text="Task completed, but no response message."
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif current_state == TaskState.FAILED:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[
|
||||||
|
Part(text="The task failed during processing.")
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else: # CANCELED
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[Part(text="The task was canceled.")],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store in the session state for future reference
|
||||||
|
if ctx.session:
|
||||||
|
try:
|
||||||
|
ctx.session.state["a2a_task_result"] = task_status[
|
||||||
|
"result"
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
break # Exit the loop of checking
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error checking task: {str(e)}")
|
||||||
|
|
||||||
|
# If the timeout was exceeded, inform the user
|
||||||
|
if time.time() - start_time > self.max_wait_time:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[Part(text=f"Error checking task: {str(e)}")],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Wait before the next check
|
||||||
|
await asyncio.sleep(self.poll_interval)
|
||||||
|
|
||||||
|
# If the timeout was exceeded
|
||||||
|
if time.time() - start_time >= self.timeout:
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[
|
||||||
|
Part(
|
||||||
|
text="The operation exceeded the timeout. Please try again later."
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle any uncaught error
|
||||||
|
print(f"Error executing A2A agent: {str(e)}")
|
||||||
|
yield Event(
|
||||||
|
author=self.name,
|
||||||
|
content=Content(
|
||||||
|
role="agent",
|
||||||
|
parts=[Part(text=f"Error interacting with A2A agent: {str(e)}")],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Ensure that the final event is always generated
|
||||||
|
yield Event(author=self.name)
|
@ -5,10 +5,8 @@ This module implements a JSON-RPC compatible server for the A2A protocol,
|
|||||||
that manages agent tasks, streaming events and push notifications.
|
that manages agent tasks, streaming events and push notifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -16,18 +14,14 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
|
||||||
Union,
|
Union,
|
||||||
AsyncIterable,
|
AsyncIterable,
|
||||||
)
|
)
|
||||||
import httpx
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.schemas.a2a.types import A2ARequest
|
from src.schemas.a2a.types import A2ARequest
|
||||||
from src.services.agent_runner import run_agent
|
|
||||||
from src.services.a2a_integration_service import (
|
from src.services.a2a_integration_service import (
|
||||||
AgentRunnerAdapter,
|
AgentRunnerAdapter,
|
||||||
StreamingServiceAdapter,
|
StreamingServiceAdapter,
|
||||||
@ -42,9 +36,7 @@ from src.schemas.a2a.types import (
|
|||||||
SetTaskPushNotificationRequest,
|
SetTaskPushNotificationRequest,
|
||||||
GetTaskPushNotificationRequest,
|
GetTaskPushNotificationRequest,
|
||||||
TaskResubscriptionRequest,
|
TaskResubscriptionRequest,
|
||||||
TaskSendParams,
|
|
||||||
)
|
)
|
||||||
from src.utils.a2a_utils import are_modalities_compatible
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -579,7 +571,9 @@ class A2AServer:
|
|||||||
"error": {
|
"error": {
|
||||||
"code": -32601,
|
"code": -32601,
|
||||||
"message": "Method not found",
|
"message": "Method not found",
|
||||||
"data": {"detail": f"Method not supported"},
|
"data": {
|
||||||
|
"detail": f"Method not supported: {method}"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -6,12 +6,9 @@ including execution, streaming, push notifications, status queries, and cancella
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Union, AsyncIterable
|
from typing import Any, Dict, Union, AsyncIterable
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from src.schemas.a2a.exceptions import (
|
from src.schemas.a2a.exceptions import (
|
||||||
TaskNotFoundError,
|
TaskNotFoundError,
|
||||||
@ -23,8 +20,6 @@ from src.schemas.a2a.exceptions import (
|
|||||||
|
|
||||||
from src.schemas.a2a.types import (
|
from src.schemas.a2a.types import (
|
||||||
JSONRPCResponse,
|
JSONRPCResponse,
|
||||||
TaskIdParams,
|
|
||||||
TaskQueryParams,
|
|
||||||
GetTaskRequest,
|
GetTaskRequest,
|
||||||
SendTaskRequest,
|
SendTaskRequest,
|
||||||
CancelTaskRequest,
|
CancelTaskRequest,
|
||||||
@ -55,9 +50,6 @@ from src.services.redis_cache_service import RedisCacheService
|
|||||||
from src.utils.a2a_utils import (
|
from src.utils.a2a_utils import (
|
||||||
are_modalities_compatible,
|
are_modalities_compatible,
|
||||||
new_incompatible_types_error,
|
new_incompatible_types_error,
|
||||||
new_not_implemented_error,
|
|
||||||
create_task_id,
|
|
||||||
format_error_response,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
from google.adk.agents.llm_agent import LlmAgent
|
from google.adk.agents.llm_agent import LlmAgent
|
||||||
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent
|
from google.adk.agents import SequentialAgent, ParallelAgent, LoopAgent, BaseAgent
|
||||||
from google.adk.models.lite_llm import LiteLlm
|
from google.adk.models.lite_llm import LiteLlm
|
||||||
from src.utils.logger import setup_logger
|
from src.utils.logger import setup_logger
|
||||||
from src.core.exceptions import AgentNotFoundError
|
from src.core.exceptions import AgentNotFoundError
|
||||||
from src.services.agent_service import get_agent
|
from src.services.agent_service import get_agent
|
||||||
from src.services.custom_tools import CustomToolBuilder
|
from src.services.custom_tools import CustomToolBuilder
|
||||||
from src.services.mcp_service import MCPService
|
from src.services.mcp_service import MCPService
|
||||||
|
from src.services.a2a_agent import A2ACustomAgent
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from google.adk.tools import load_memory
|
from google.adk.tools import load_memory
|
||||||
@ -87,12 +88,19 @@ class AgentBuilder:
|
|||||||
if agent is None:
|
if agent is None:
|
||||||
raise AgentNotFoundError(f"Agent with ID {sub_agent_id} not found")
|
raise AgentNotFoundError(f"Agent with ID {sub_agent_id} not found")
|
||||||
|
|
||||||
if agent.type != "llm":
|
if agent.type == "llm":
|
||||||
raise ValueError(
|
sub_agent, exit_stack = await self._create_llm_agent(agent)
|
||||||
f"Agent {agent.name} (ID: {agent.id}) is not an LLM agent"
|
elif agent.type == "a2a":
|
||||||
)
|
sub_agent, exit_stack = await self.build_a2a_agent(agent)
|
||||||
|
elif agent.type == "sequential":
|
||||||
|
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||||
|
elif agent.type == "parallel":
|
||||||
|
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||||
|
elif agent.type == "loop":
|
||||||
|
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid agent type: {agent.type}")
|
||||||
|
|
||||||
sub_agent, exit_stack = await self._create_llm_agent(agent)
|
|
||||||
sub_agents.append((sub_agent, exit_stack))
|
sub_agents.append((sub_agent, exit_stack))
|
||||||
|
|
||||||
return sub_agents
|
return sub_agents
|
||||||
@ -116,6 +124,41 @@ class AgentBuilder:
|
|||||||
|
|
||||||
return root_llm_agent, exit_stack
|
return root_llm_agent, exit_stack
|
||||||
|
|
||||||
|
async def build_a2a_agent(
|
||||||
|
self, root_agent
|
||||||
|
) -> Tuple[BaseAgent, Optional[AsyncExitStack]]:
|
||||||
|
"""Build an A2A agent with its sub-agents."""
|
||||||
|
logger.info(f"Creating A2A agent from {root_agent.agent_card_url}")
|
||||||
|
|
||||||
|
if not root_agent.agent_card_url:
|
||||||
|
raise ValueError("agent_card_url is required for a2a agents")
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = root_agent.config or {}
|
||||||
|
poll_interval = config.get("poll_interval", 1.0)
|
||||||
|
max_wait_time = config.get("max_wait_time", 60)
|
||||||
|
timeout = config.get("timeout", 300)
|
||||||
|
|
||||||
|
a2a_agent = A2ACustomAgent(
|
||||||
|
name=root_agent.name,
|
||||||
|
agent_card_url=root_agent.agent_card_url,
|
||||||
|
poll_interval=poll_interval,
|
||||||
|
max_wait_time=max_wait_time,
|
||||||
|
timeout=timeout,
|
||||||
|
description=root_agent.description
|
||||||
|
or f"A2A Agent for {root_agent.name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"A2A agent created successfully: {root_agent.name} ({root_agent.agent_card_url})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return a2a_agent, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error building A2A agent: {str(e)}")
|
||||||
|
raise ValueError(f"Error building A2A agent: {str(e)}")
|
||||||
|
|
||||||
async def build_composite_agent(
|
async def build_composite_agent(
|
||||||
self, root_agent
|
self, root_agent
|
||||||
) -> Tuple[SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack]]:
|
) -> Tuple[SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack]]:
|
||||||
@ -161,13 +204,14 @@ class AgentBuilder:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid agent type: {root_agent.type}")
|
raise ValueError(f"Invalid agent type: {root_agent.type}")
|
||||||
|
|
||||||
async def build_agent(
|
async def build_agent(self, root_agent) -> Tuple[
|
||||||
self, root_agent
|
LlmAgent | SequentialAgent | ParallelAgent | LoopAgent | A2ACustomAgent,
|
||||||
) -> Tuple[
|
Optional[AsyncExitStack],
|
||||||
LlmAgent | SequentialAgent | ParallelAgent | LoopAgent, Optional[AsyncExitStack]
|
|
||||||
]:
|
]:
|
||||||
"""Build the appropriate agent based on the type of the root agent."""
|
"""Build the appropriate agent based on the type of the root agent."""
|
||||||
if root_agent.type == "llm":
|
if root_agent.type == "llm":
|
||||||
return await self.build_llm_agent(root_agent)
|
return await self.build_llm_agent(root_agent)
|
||||||
|
elif root_agent.type == "a2a":
|
||||||
|
return await self.build_a2a_agent(root_agent)
|
||||||
else:
|
else:
|
||||||
return await self.build_composite_agent(root_agent)
|
return await self.build_composite_agent(root_agent)
|
||||||
|
@ -17,7 +17,7 @@ import jwt
|
|||||||
from jwt import PyJWK, PyJWKClient
|
from jwt import PyJWK, PyJWKClient
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
AUTH_HEADER_PREFIX = "Bearer "
|
AUTH_HEADER_PREFIX = "Bearer "
|
||||||
|
@ -7,11 +7,10 @@ push notification configurations, and other A2A-related data.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
from redis.exceptions import RedisError
|
from src.config.redis import get_redis_config
|
||||||
from src.config.redis import get_redis_config, get_a2a_config
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user