refactor(a2a): restructure A2A routes and services for improved task management and API key verification

This commit is contained in:
Davidson Gomes 2025-05-05 21:22:53 -03:00
parent a0f984ae21
commit ec9dc07d71
16 changed files with 819 additions and 4244 deletions

View File

@ -6,29 +6,18 @@ This module implements the standard A2A routes according to the specification.
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
import json
from fastapi import APIRouter, Depends, Header, Request, HTTPException
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from src.models.models import Agent
from src.config.database import get_db
from src.services import agent_service
from src.services import (
RedisCacheService,
AgentRunnerAdapter,
StreamingServiceAdapter,
create_agent_card_from_agent,
from src.services.a2a_task_manager import (
A2ATaskManager,
A2AService,
)
from src.services.a2a_task_manager_service import A2ATaskManager
from src.services.a2a_server_service import A2AServer
from src.services.agent_runner import run_agent
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from src.services.push_notification_service import push_notification_service
from src.services.push_notification_auth_service import push_notification_auth
from src.services.streaming_service import StreamingService
logger = logging.getLogger(__name__)
@ -43,85 +32,27 @@ router = APIRouter(
},
)
streaming_service = StreamingService()
redis_cache_service = RedisCacheService()
streaming_adapter = StreamingServiceAdapter(streaming_service)
_task_manager_cache = {}
_agent_runner_cache = {}
def get_a2a_service(db: Session = Depends(get_db)):
task_manager = A2ATaskManager(db)
return A2AService(db, task_manager)
def get_agent_runner_adapter(db=None, reuse=True, agent_id=None):
"""
Get or create an agent runner adapter.
async def verify_api_key(db: Session, x_api_key: str) -> bool:
"""Verifies the API key."""
if not x_api_key:
raise HTTPException(status_code=401, detail="API key not provided")
Args:
db: Database session
reuse: Whether to reuse an existing instance
agent_id: Agent ID to use as cache key
Returns:
Agent runner adapter instance
"""
cache_key = str(agent_id) if agent_id else "default"
if reuse and cache_key in _agent_runner_cache:
adapter = _agent_runner_cache[cache_key]
if db is not None:
adapter.db = db
return adapter
adapter = AgentRunnerAdapter(
agent_runner_func=run_agent,
session_service=session_service,
artifacts_service=artifacts_service,
memory_service=memory_service,
db=db,
agent = (
db.query(Agent)
.filter(Agent.config.has_key("api_key"))
.filter(Agent.config["api_key"].astext == x_api_key)
.first()
)
if reuse:
_agent_runner_cache[cache_key] = adapter
return adapter
def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"):
cache_key = str(agent_id)
if operation_type == "query":
if cache_key in _task_manager_cache:
task_manager = _task_manager_cache[cache_key]
task_manager.db = db
return task_manager
return A2ATaskManager(
redis_cache=redis_cache_service,
agent_runner=None,
streaming_service=streaming_adapter,
push_notification_service=push_notification_service,
db=db,
)
if reuse and cache_key in _task_manager_cache:
task_manager = _task_manager_cache[cache_key]
task_manager.db = db
return task_manager
# Create new
agent_runner_adapter = get_agent_runner_adapter(
db=db, reuse=reuse, agent_id=agent_id
)
task_manager = A2ATaskManager(
redis_cache=redis_cache_service,
agent_runner=agent_runner_adapter,
streaming_service=streaming_adapter,
push_notification_service=push_notification_service,
db=db,
)
_task_manager_cache[cache_key] = task_manager
return task_manager
if not agent:
raise HTTPException(status_code=401, detail="Invalid API key")
return True
@router.post("/{agent_id}")
@ -130,156 +61,42 @@ async def process_a2a_request(
request: Request,
x_api_key: str = Header(None, alias="x-api-key"),
db: Session = Depends(get_db),
a2a_service: A2AService = Depends(get_a2a_service),
):
"""
Main endpoint for processing JSON-RPC requests of the A2A protocol.
"""Processes an A2A request."""
# Verify the API key
if not verify_api_key(db, x_api_key):
raise HTTPException(status_code=401, detail="Invalid API key")
This endpoint processes all JSON-RPC methods of the A2A protocol, including:
- tasks/send: Sending tasks
- tasks/sendSubscribe: Sending tasks with streaming
- tasks/get: Querying task status
- tasks/cancel: Cancelling tasks
- tasks/pushNotification/set: Setting push notifications
- tasks/pushNotification/get: Querying push notification configurations
- tasks/resubscribe: Resubscribing to receive task updates
Args:
agent_id: Agent ID
request: HTTP request with JSON-RPC payload
x_api_key: API key for authentication
db: Database session
Returns:
JSON-RPC response or streaming (SSE) depending on the method
"""
# Process the request
try:
try:
body = await request.json()
method = body.get("method", "unknown")
request_id = body.get("id") # Extract request ID to ensure it's preserved
request_body = await request.json()
result = await a2a_service.process_request(agent_id, request_body)
# Extrair sessionId do params se for tasks/send
session_id = None
if method == "tasks/send" and body.get("params"):
session_id = body.get("params", {}).get("sessionId")
logger.info(f"Extracted sessionId from request: {session_id}")
# If the response is a streaming response, return as EventSourceResponse
if hasattr(result, "__aiter__"):
is_query_request = method in [
"tasks/get",
"tasks/cancel",
"tasks/pushNotification/get",
"tasks/resubscribe",
]
async def event_generator():
async for item in result:
if hasattr(item, "model_dump_json"):
yield {"data": item.model_dump_json(exclude_none=True)}
else:
yield {"data": json.dumps(item)}
reuse_components = is_query_request
except Exception as e:
logger.error(f"Error reading request body: {e}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32700,
"message": f"Parse error: {str(e)}",
"data": None,
},
},
)
# Verify if the agent exists
agent = agent_service.get_agent(db, agent_id)
if agent is None:
logger.warning(f"Agent not found: {agent_id}")
return JSONResponse(
status_code=404,
content={
"jsonrpc": "2.0",
"id": request_id, # Use the extracted request ID
"error": {"code": 404, "message": "Agent not found", "data": None},
},
)
# Verify API key
agent_config = agent.config
if x_api_key and agent_config.get("api_key") != x_api_key:
logger.warning(f"Invalid API Key for agent {agent_id}")
return JSONResponse(
status_code=401,
content={
"jsonrpc": "2.0",
"id": request_id, # Use the extracted request ID
"error": {"code": 401, "message": "Invalid API key", "data": None},
},
)
a2a_task_manager = get_task_manager(
agent_id,
db=db,
reuse=reuse_components,
operation_type="query" if is_query_request else "execution",
)
a2a_server = A2AServer(task_manager=a2a_task_manager)
agent_card = create_agent_card_from_agent(agent, db)
a2a_server.agent_card = agent_card
# Verify JSON-RPC format
if not body.get("jsonrpc") or body.get("jsonrpc") != "2.0":
logger.error(f"Invalid JSON-RPC format: {body.get('jsonrpc')}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": request_id, # Use the extracted request ID
"error": {
"code": -32600,
"message": "Invalid Request: jsonrpc must be '2.0'",
"data": None,
},
},
)
if not body.get("method"):
logger.error("Method not specified in request")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": request_id, # Use the extracted request ID
"error": {
"code": -32600,
"message": "Invalid Request: method is required",
"data": None,
},
},
)
# Processar a requisição normalmente
return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
return EventSourceResponse(event_generator())
# Otherwise, return as JSONResponse
if hasattr(result, "model_dump"):
return JSONResponse(result.model_dump(exclude_none=True))
return JSONResponse(result)
except Exception as e:
logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
# Try to extract request ID from the body, if available
request_id = None
try:
body = await request.json()
request_id = body.get("id")
except:
pass
logger.error(f"Error processing A2A request: {e}")
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"id": request_id, # Use the extracted request ID or None
"error": {
"code": -32603,
"message": "Internal server error",
"data": {"detail": str(e)},
},
"id": None,
"error": {"code": -32603, "message": "Internal server error"},
},
)
@ -289,81 +106,17 @@ async def get_agent_card(
agent_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
a2a_service: A2AService = Depends(get_a2a_service),
):
"""
Endpoint to get the Agent Card in the .well-known format of the A2A protocol.
This endpoint returns the agent information in the standard A2A format,
including capabilities, authentication information, and skills.
Args:
agent_id: Agent ID
request: HTTP request
db: Database session
Returns:
Agent Card in JSON format
"""
"""Gets the agent card for the specified agent."""
try:
agent = agent_service.get_agent(db, agent_id)
if agent is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
)
agent_card = create_agent_card_from_agent(agent, db)
a2a_task_manager = get_task_manager(agent_id, db=db, reuse=True)
a2a_server = A2AServer(task_manager=a2a_task_manager)
# Configure the A2A server with the agent card
a2a_server.agent_card = agent_card
# Use the A2A server to deliver the agent card, ensuring protocol compatibility
return await a2a_server.get_agent_card(request, db=db)
agent_card = a2a_service.get_agent_card(agent_id)
if hasattr(agent_card, "model_dump"):
return JSONResponse(agent_card.model_dump(exclude_none=True))
return JSONResponse(agent_card)
except Exception as e:
logger.error(f"Error generating agent card: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error generating agent card",
)
@router.get("/{agent_id}/.well-known/jwks.json")
async def get_jwks(
agent_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
):
"""
Endpoint to get the public JWKS keys for verifying the authenticity
of push notifications.
Clients can use these keys to verify the authenticity of received notifications.
Args:
agent_id: Agent ID
request: HTTP request
db: Database session
Returns:
JSON with the public keys in JWKS format
"""
try:
# Verify if the agent exists
agent = agent_service.get_agent(db, agent_id)
if agent is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
)
# Return the public keys
return push_notification_auth.handle_jwks_endpoint(request)
except Exception as e:
logger.error(f"Error obtaining JWKS: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error obtaining JWKS",
logger.error(f"Error getting agent card: {e}")
return JSONResponse(
status_code=404,
content={"error": f"Agent not found: {str(e)}"},
)

View File

@ -1,140 +0,0 @@
"""
A2A (Agent-to-Agent) schema package.
This package contains Pydantic schema definitions for the A2A protocol.
"""
from src.schemas.a2a.types import (
TaskState,
TextPart,
FileContent,
FilePart,
DataPart,
Part,
Message,
TaskStatus,
Artifact,
Task,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
AuthenticationInfo,
PushNotificationConfig,
TaskIdParams,
TaskQueryParams,
TaskSendParams,
TaskPushNotificationConfig,
JSONRPCMessage,
JSONRPCRequest,
JSONRPCResponse,
JSONRPCError,
SendTaskRequest,
SendTaskResponse,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
GetTaskRequest,
GetTaskResponse,
CancelTaskRequest,
CancelTaskResponse,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
TaskResubscriptionRequest,
A2ARequest,
AgentProvider,
AgentCapabilities,
AgentAuthentication,
AgentSkill,
AgentCard,
)
from src.schemas.a2a.exceptions import (
JSONParseError,
InvalidRequestError,
MethodNotFoundError,
InvalidParamsError,
InternalError,
TaskNotFoundError,
TaskNotCancelableError,
PushNotificationNotSupportedError,
UnsupportedOperationError,
ContentTypeNotSupportedError,
A2AClientError,
A2AClientHTTPError,
A2AClientJSONError,
MissingAPIKeyError,
)
from src.schemas.a2a.validators import (
validate_base64,
validate_file_content,
validate_message_parts,
text_to_parts,
parts_to_text,
)
__all__ = [
# From types
"TaskState",
"TextPart",
"FileContent",
"FilePart",
"DataPart",
"Part",
"Message",
"TaskStatus",
"Artifact",
"Task",
"TaskStatusUpdateEvent",
"TaskArtifactUpdateEvent",
"AuthenticationInfo",
"PushNotificationConfig",
"TaskIdParams",
"TaskQueryParams",
"TaskSendParams",
"TaskPushNotificationConfig",
"JSONRPCMessage",
"JSONRPCRequest",
"JSONRPCResponse",
"JSONRPCError",
"SendTaskRequest",
"SendTaskResponse",
"SendTaskStreamingRequest",
"SendTaskStreamingResponse",
"GetTaskRequest",
"GetTaskResponse",
"CancelTaskRequest",
"CancelTaskResponse",
"SetTaskPushNotificationRequest",
"SetTaskPushNotificationResponse",
"GetTaskPushNotificationRequest",
"GetTaskPushNotificationResponse",
"TaskResubscriptionRequest",
"A2ARequest",
"AgentProvider",
"AgentCapabilities",
"AgentAuthentication",
"AgentSkill",
"AgentCard",
# From exceptions
"JSONParseError",
"InvalidRequestError",
"MethodNotFoundError",
"InvalidParamsError",
"InternalError",
"TaskNotFoundError",
"TaskNotCancelableError",
"PushNotificationNotSupportedError",
"UnsupportedOperationError",
"ContentTypeNotSupportedError",
"A2AClientError",
"A2AClientHTTPError",
"A2AClientJSONError",
"MissingAPIKeyError",
# From validators
"validate_base64",
"validate_file_content",
"validate_message_parts",
"text_to_parts",
"parts_to_text",
]

View File

@ -1,147 +0,0 @@
"""
A2A (Agent-to-Agent) protocol exception definitions.
This module contains error types and exceptions for the A2A protocol.
"""
from src.schemas.a2a.types import JSONRPCError
class JSONParseError(JSONRPCError):
"""
Error raised when JSON parsing fails.
"""
code: int = -32700
message: str = "Invalid JSON payload"
data: object | None = None
class InvalidRequestError(JSONRPCError):
"""
Error raised when request validation fails.
"""
code: int = -32600
message: str = "Request payload validation error"
data: object | None = None
class MethodNotFoundError(JSONRPCError):
"""
Error raised when the requested method is not found.
"""
code: int = -32601
message: str = "Method not found"
data: None = None
class InvalidParamsError(JSONRPCError):
"""
Error raised when the parameters are invalid.
"""
code: int = -32602
message: str = "Invalid parameters"
data: object | None = None
class InternalError(JSONRPCError):
"""
Error raised when an internal error occurs.
"""
code: int = -32603
message: str = "Internal error"
data: object | None = None
class TaskNotFoundError(JSONRPCError):
"""
Error raised when the requested task is not found.
"""
code: int = -32001
message: str = "Task not found"
data: None = None
class TaskNotCancelableError(JSONRPCError):
"""
Error raised when a task cannot be canceled.
"""
code: int = -32002
message: str = "Task cannot be canceled"
data: None = None
class PushNotificationNotSupportedError(JSONRPCError):
"""
Error raised when push notifications are not supported.
"""
code: int = -32003
message: str = "Push Notification is not supported"
data: None = None
class UnsupportedOperationError(JSONRPCError):
"""
Error raised when an operation is not supported.
"""
code: int = -32004
message: str = "This operation is not supported"
data: None = None
class ContentTypeNotSupportedError(JSONRPCError):
"""
Error raised when content types are incompatible.
"""
code: int = -32005
message: str = "Incompatible content types"
data: None = None
# Client exceptions
class A2AClientError(Exception):
"""
Base exception for A2A client errors.
"""
pass
class A2AClientHTTPError(A2AClientError):
"""
Exception for HTTP errors in A2A client.
"""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f"HTTP Error {status_code}: {message}")
class A2AClientJSONError(A2AClientError):
"""
Exception for JSON errors in A2A client.
"""
def __init__(self, message: str):
self.message = message
super().__init__(f"JSON Error: {message}")
class MissingAPIKeyError(Exception):
"""
Exception for missing API key.
"""
pass

View File

@ -1,124 +0,0 @@
"""
A2A (Agent-to-Agent) protocol validators.
This module contains validators for the A2A protocol data.
"""
from typing import List
import base64
import re
from pydantic import ValidationError
import logging
from src.schemas.a2a.types import Part, TextPart, FilePart, DataPart, FileContent
logger = logging.getLogger(__name__)
def validate_base64(value: str) -> bool:
"""
Validates if a string is valid base64.
Args:
value: String to validate
Returns:
True if valid base64, False otherwise
"""
try:
if not value:
return False
# Check if the string has base64 characters only
pattern = r"^[A-Za-z0-9+/]+={0,2}$"
if not re.match(pattern, value):
return False
# Try to decode
base64.b64decode(value)
return True
except Exception as e:
logger.warning(f"Invalid base64 string: {e}")
return False
def validate_file_content(file_content: FileContent) -> bool:
"""
Validates file content.
Args:
file_content: FileContent to validate
Returns:
True if valid, False otherwise
"""
try:
if file_content.bytes is not None:
return validate_base64(file_content.bytes)
elif file_content.uri is not None:
# Basic URL validation
pattern = r"^https?://.+"
return bool(re.match(pattern, file_content.uri))
return False
except Exception as e:
logger.warning(f"Invalid file content: {e}")
return False
def validate_message_parts(parts: List[Part]) -> bool:
"""
Validates all parts in a message.
Args:
parts: List of parts to validate
Returns:
True if all parts are valid, False otherwise
"""
try:
for part in parts:
if isinstance(part, TextPart):
if not part.text or not isinstance(part.text, str):
return False
elif isinstance(part, FilePart):
if not validate_file_content(part.file):
return False
elif isinstance(part, DataPart):
if not part.data or not isinstance(part.data, dict):
return False
else:
return False
return True
except (ValidationError, Exception) as e:
logger.warning(f"Invalid message parts: {e}")
return False
def text_to_parts(text: str) -> List[Part]:
"""
Converts a plain text to a list of message parts.
Args:
text: Plain text to convert
Returns:
List containing a single TextPart
"""
return [TextPart(text=text)]
def parts_to_text(parts: List[Part]) -> str:
"""
Extracts text from a list of message parts.
Args:
parts: List of parts to extract text from
Returns:
Concatenated text from all text parts
"""
text = ""
for part in parts:
if isinstance(part, TextPart):
text += part.text
# Could add handling for other part types here
return text

View File

@ -1,31 +1,20 @@
"""
A2A (Agent-to-Agent) protocol type definitions.
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, TypeVar
from uuid import uuid4
from typing_extensions import Self
This module contains Pydantic schema definitions for the A2A protocol.
"""
from typing import Union, Any, List, Optional, Annotated, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
field_serializer,
model_validator,
ConfigDict,
)
from datetime import datetime
from uuid import uuid4
from enum import Enum
from typing_extensions import Self
class TaskState(str, Enum):
"""
Enum for the state of a task in the A2A protocol.
States follow the A2A protocol specification.
"""
SUBMITTED = "submitted"
WORKING = "working"
INPUT_REQUIRED = "input-required"
@ -36,22 +25,12 @@ class TaskState(str, Enum):
class TextPart(BaseModel):
"""
Represents a text part in a message.
"""
type: Literal["text"] = "text"
text: str
metadata: dict[str, Any] | None = None
class FileContent(BaseModel):
"""
Represents file content in a file part.
Either bytes or uri must be provided, but not both.
"""
name: str | None = None
mimeType: str | None = None
bytes: str | None = None
@ -59,9 +38,6 @@ class FileContent(BaseModel):
@model_validator(mode="after")
def check_content(self) -> Self:
"""
Validates that either bytes or uri is present, but not both.
"""
if not (self.bytes or self.uri):
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
if self.bytes and self.uri:
@ -72,65 +48,40 @@ class FileContent(BaseModel):
class FilePart(BaseModel):
"""
Represents a file part in a message.
"""
type: Literal["file"] = "file"
file: FileContent
metadata: dict[str, Any] | None = None
class DataPart(BaseModel):
"""
Represents a data part in a message.
"""
type: Literal["data"] = "data"
data: dict[str, Any]
metadata: dict[str, Any] | None = None
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
Part = Annotated[TextPart | FilePart | DataPart, Field(discriminator="type")]
class Message(BaseModel):
"""
Represents a message in the A2A protocol.
A message consists of a role and one or more parts.
"""
role: Literal["user", "agent"]
parts: List[Part]
parts: list[Part]
metadata: dict[str, Any] | None = None
class TaskStatus(BaseModel):
"""
Represents the status of a task.
"""
state: TaskState
message: Message | None = None
timestamp: datetime = Field(default_factory=datetime.now)
@field_serializer("timestamp")
def serialize_dt(self, dt: datetime, _info):
"""
Serializes datetime to ISO format.
"""
return dt.isoformat()
class Artifact(BaseModel):
"""
Represents an artifact produced by an agent.
"""
name: str | None = None
description: str | None = None
parts: List[Part]
parts: list[Part]
metadata: dict[str, Any] | None = None
index: int = 0
append: bool | None = None
@ -138,23 +89,15 @@ class Artifact(BaseModel):
class Task(BaseModel):
"""
Represents a task in the A2A protocol.
"""
id: str
sessionId: str | None = None
status: TaskStatus
artifacts: List[Artifact] | None = None
history: List[Message] | None = None
artifacts: list[Artifact] | None = None
history: list[Message] | None = None
metadata: dict[str, Any] | None = None
class TaskStatusUpdateEvent(BaseModel):
"""
Represents a task status update event.
"""
id: str
status: TaskStatus
final: bool = False
@ -162,295 +105,234 @@ class TaskStatusUpdateEvent(BaseModel):
class TaskArtifactUpdateEvent(BaseModel):
"""
Represents a task artifact update event.
"""
id: str
artifact: Artifact
metadata: dict[str, Any] | None = None
class AuthenticationInfo(BaseModel):
"""
Represents authentication information for push notifications.
"""
model_config = ConfigDict(extra="allow")
schemes: List[str]
schemes: list[str]
credentials: str | None = None
class PushNotificationConfig(BaseModel):
"""
Represents push notification configuration.
"""
url: str
token: str | None = None
authentication: AuthenticationInfo | None = None
class TaskIdParams(BaseModel):
"""
Represents parameters for identifying a task.
"""
id: str
metadata: dict[str, Any] | None = None
class TaskQueryParams(TaskIdParams):
"""
Represents parameters for querying a task.
"""
historyLength: int | None = None
class TaskSendParams(BaseModel):
"""
Represents parameters for sending a task.
"""
id: str
sessionId: str = Field(default_factory=lambda: uuid4().hex)
message: Message
acceptedOutputModes: Optional[List[str]] = None
acceptedOutputModes: list[str] | None = None
pushNotification: PushNotificationConfig | None = None
historyLength: int | None = None
metadata: dict[str, Any] | None = None
agentId: str = ""
class TaskPushNotificationConfig(BaseModel):
"""
Represents push notification configuration for a task.
"""
id: str
pushNotificationConfig: PushNotificationConfig
# RPC Messages
## RPC Messages
class JSONRPCMessage(BaseModel):
"""
Base class for JSON-RPC messages.
"""
jsonrpc: Literal["2.0"] = "2.0"
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
class JSONRPCRequest(JSONRPCMessage):
"""
Represents a JSON-RPC request.
"""
method: str
params: dict[str, Any] | None = None
class JSONRPCError(BaseModel):
"""
Represents a JSON-RPC error.
"""
code: int
message: str
data: Any | None = None
class JSONRPCResponse(JSONRPCMessage):
"""
Represents a JSON-RPC response.
"""
result: Any | None = None
error: JSONRPCError | None = None
class SendTaskRequest(JSONRPCRequest):
"""
Represents a request to send a task.
"""
method: Literal["tasks/send"] = "tasks/send"
params: TaskSendParams
class SendTaskResponse(JSONRPCResponse):
"""
Represents a response to a send task request.
"""
result: Task | None = None
class SendTaskStreamingRequest(JSONRPCRequest):
"""
Represents a request to send a task with streaming.
"""
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
params: TaskSendParams
class SendTaskStreamingResponse(JSONRPCResponse):
"""
Represents a streaming response to a send task request.
"""
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
class GetTaskRequest(JSONRPCRequest):
"""
Represents a request to get task information.
"""
method: Literal["tasks/get"] = "tasks/get"
params: TaskQueryParams
class GetTaskResponse(JSONRPCResponse):
"""
Represents a response to a get task request.
"""
result: Task | None = None
class CancelTaskRequest(JSONRPCRequest):
"""
Represents a request to cancel a task.
"""
method: Literal["tasks/cancel",] = "tasks/cancel"
params: TaskIdParams
class CancelTaskResponse(JSONRPCResponse):
"""
Represents a response to a cancel task request.
"""
result: Task | None = None
class SetTaskPushNotificationRequest(JSONRPCRequest):
"""
Represents a request to set push notification for a task.
"""
method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
params: TaskPushNotificationConfig
class SetTaskPushNotificationResponse(JSONRPCResponse):
"""
Represents a response to a set push notification request.
"""
result: TaskPushNotificationConfig | None = None
class GetTaskPushNotificationRequest(JSONRPCRequest):
"""
Represents a request to get push notification configuration for a task.
"""
method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
params: TaskIdParams
class GetTaskPushNotificationResponse(JSONRPCResponse):
"""
Represents a response to a get push notification request.
"""
result: TaskPushNotificationConfig | None = None
class TaskResubscriptionRequest(JSONRPCRequest):
"""
Represents a request to resubscribe to a task.
"""
method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
params: TaskIdParams
# TypeAdapter for discriminating A2A requests by method
A2ARequest = TypeAdapter(
Annotated[
Union[
SendTaskRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
],
SendTaskRequest
| GetTaskRequest
| CancelTaskRequest
| SetTaskPushNotificationRequest
| GetTaskPushNotificationRequest
| TaskResubscriptionRequest
| SendTaskStreamingRequest,
Field(discriminator="method"),
]
)
## Error types
# Agent Card schemas
class JSONParseError(JSONRPCError):
code: int = -32700
message: str = "Invalid JSON payload"
data: Any | None = None
class InvalidRequestError(JSONRPCError):
code: int = -32600
message: str = "Request payload validation error"
data: Any | None = None
class MethodNotFoundError(JSONRPCError):
code: int = -32601
message: str = "Method not found"
data: None = None
class InvalidParamsError(JSONRPCError):
code: int = -32602
message: str = "Invalid parameters"
data: Any | None = None
class InternalError(JSONRPCError):
code: int = -32603
message: str = "Internal error"
data: Any | None = None
class TaskNotFoundError(JSONRPCError):
code: int = -32001
message: str = "Task not found"
data: None = None
class TaskNotCancelableError(JSONRPCError):
code: int = -32002
message: str = "Task cannot be canceled"
data: None = None
class PushNotificationNotSupportedError(JSONRPCError):
code: int = -32003
message: str = "Push Notification is not supported"
data: None = None
class UnsupportedOperationError(JSONRPCError):
code: int = -32004
message: str = "This operation is not supported"
data: None = None
class ContentTypeNotSupportedError(JSONRPCError):
code: int = -32005
message: str = "Incompatible content types"
data: None = None
class AgentProvider(BaseModel):
"""
Represents the provider of an agent.
"""
organization: str
url: str | None = None
class AgentCapabilities(BaseModel):
"""
Represents the capabilities of an agent.
"""
streaming: bool = False
pushNotifications: bool = False
stateTransitionHistory: bool = False
class AgentAuthentication(BaseModel):
"""
Represents the authentication requirements for an agent.
"""
schemes: List[str]
schemes: list[str]
credentials: str | None = None
class AgentSkill(BaseModel):
"""
Represents a skill of an agent.
"""
id: str
name: str
description: str | None = None
tags: List[str] | None = None
examples: List[str] | None = None
inputModes: List[str] | None = None
outputModes: List[str] | None = None
tags: list[str] | None = None
examples: list[str] | None = None
inputModes: list[str] | None = None
outputModes: list[str] | None = None
class AgentCard(BaseModel):
"""
Represents an agent card in the A2A protocol.
"""
name: str
description: str | None = None
url: str
@ -459,6 +341,27 @@ class AgentCard(BaseModel):
documentationUrl: str | None = None
capabilities: AgentCapabilities
authentication: AgentAuthentication | None = None
defaultInputModes: List[str] = ["text"]
defaultOutputModes: List[str] = ["text"]
skills: List[AgentSkill]
defaultInputModes: list[str] = ["text"]
defaultOutputModes: list[str] = ["text"]
skills: list[AgentSkill]
class A2AClientError(Exception):
pass
class A2AClientHTTPError(A2AClientError):
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f"HTTP Error {status_code}: {message}")
class A2AClientJSONError(A2AClientError):
def __init__(self, message: str):
self.message = message
super().__init__(f"JSON Error: {message}")
class MissingAPIKeyError(Exception):
"""Exception for missing API key."""

View File

@ -1,9 +1 @@
from .agent_runner import run_agent
from .redis_cache_service import RedisCacheService
from .a2a_task_manager_service import A2ATaskManager
from .a2a_server_service import A2AServer
from .a2a_integration_service import (
AgentRunnerAdapter,
StreamingServiceAdapter,
create_agent_card_from_agent,
)

View File

@ -7,7 +7,7 @@ import json
import asyncio
import time
from src.schemas.a2a.types import (
from src.schemas.a2a_types import (
GetTaskRequest,
SendTaskRequest,
Message,

View File

@ -1,507 +0,0 @@
"""
A2A Integration Service.
This service provides adapters to integrate existing services with the A2A protocol.
"""
import json
import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, AsyncIterable
from src.schemas.a2a import (
AgentCard,
AgentCapabilities,
AgentProvider,
Artifact,
Message,
TaskArtifactUpdateEvent,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
logger = logging.getLogger(__name__)
class AgentRunnerAdapter:
"""
Adapter for integrating the existing agent runner with the A2A protocol.
"""
def __init__(
self,
agent_runner_func,
session_service=None,
artifacts_service=None,
memory_service=None,
db=None,
):
"""
Initialize the adapter.
Args:
agent_runner_func: The agent runner function (e.g., run_agent)
session_service: Session service for message history
artifacts_service: Artifacts service for artifact history
memory_service: Memory service for agent memory
db: Database session
"""
self.agent_runner_func = agent_runner_func
self.session_service = session_service
self.artifacts_service = artifacts_service
self.memory_service = memory_service
self.db = db
async def get_supported_modes(self) -> List[str]:
"""
Get the supported output modes for the agent.
Returns:
List of supported output modes
"""
# Default modes, can be extended based on agent configuration
return ["text", "application/json"]
async def run_agent(
self,
agent_id: str,
message: str,
session_id: Optional[str] = None,
task_id: Optional[str] = None,
db=None,
) -> Dict[str, Any]:
"""
Run the agent with the given message.
Args:
agent_id: ID of the agent to run
message: User message to process
session_id: Optional session ID for conversation context
task_id: Optional task ID for tracking
db: Database session
Returns:
Dictionary with the agent's response
"""
try:
# Use the existing agent runner function
# Usar o session_id fornecido, ou gerar um novo
session_id = session_id or str(uuid.uuid4())
task_id = task_id or str(uuid.uuid4())
# Use the provided db or fallback to self.db
db_session = db if db is not None else self.db
response_text = await self.agent_runner_func(
agent_id=agent_id,
contact_id=task_id,
message=message,
session_service=self.session_service,
artifacts_service=self.artifacts_service,
memory_service=self.memory_service,
db=db_session,
session_id=session_id,
)
# Format the response to include both the A2A-compatible message
# and the artifact format to match the Google A2A implementation
# Nota: O formato dos artifacts é simplificado para compatibilidade com Google A2A
message_obj = {
"role": "agent",
"parts": [{"type": "text", "text": response_text}],
}
# Formato de artefato compatível com Google A2A
artifact_obj = {
"parts": [{"type": "text", "text": response_text}],
"index": 0,
}
return {
"status": "success",
"content": response_text,
"session_id": session_id,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
"message": message_obj,
"artifact": artifact_obj,
}
except Exception as e:
logger.error(f"[AGENT-RUNNER] Error running agent: {e}", exc_info=True)
return {
"status": "error",
"error": str(e),
"session_id": session_id,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": f"Error: {str(e)}"}],
},
}
async def cancel_task(self, task_id: str) -> bool:
"""
Cancel a running task.
Args:
task_id: ID of the task to cancel
Returns:
True if successfully canceled, False otherwise
"""
# Currently, the agent runner doesn't support cancellation
# This is a placeholder for future implementation
logger.warning(f"Task cancellation not implemented for task {task_id}")
return False
class StreamingServiceAdapter:
"""
Adapter for integrating the existing streaming service with the A2A protocol.
"""
def __init__(self, streaming_service):
"""
Initialize the adapter.
Args:
streaming_service: The streaming service instance
"""
self.streaming_service = streaming_service
async def stream_agent_response(
self,
agent_id: str,
message: str,
api_key: str,
session_id: Optional[str] = None,
task_id: Optional[str] = None,
db=None,
) -> AsyncIterable[str]:
"""
Stream the agent's response as A2A events.
Args:
agent_id: ID of the agent
message: User message to process
api_key: API key for authentication
session_id: Optional session ID for conversation context
task_id: Optional task ID for tracking
db: Database session
Yields:
A2A event objects as JSON strings for SSE (Server-Sent Events)
"""
task_id = task_id or str(uuid.uuid4())
logger.info(f"Starting streaming response for task {task_id}")
# Set working status event
working_status = TaskStatus(
state="working",
timestamp=datetime.now(),
message=Message(
role="agent", parts=[TextPart(text="Processing your request...")]
),
)
status_event = TaskStatusUpdateEvent(
id=task_id, status=working_status, final=False
)
yield json.dumps(status_event.model_dump())
content_buffer = ""
final_sent = False
has_error = False
# Stream from the existing streaming service
try:
logger.info(f"Setting up streaming for agent {agent_id}, task {task_id}")
# To streaming, we use task_id as contact_id
contact_id = task_id
last_event_time = datetime.now()
heartbeat_interval = 20
async for event in self.streaming_service.send_task_streaming(
agent_id=agent_id,
api_key=api_key,
message=message,
contact_id=contact_id,
session_id=session_id,
db=db,
):
last_event_time = datetime.now()
# Process the streaming event format
event_data = event.get("data", "{}")
try:
logger.info(f"Processing event data: {event_data[:100]}...")
data = json.loads(event_data)
# Extract content
if "delta" in data and data["delta"].get("content"):
content = data["delta"]["content"]
content_buffer += content
logger.info(f"Received content chunk: {content[:50]}...")
# Create artifact update event
artifact = Artifact(
name="response",
parts=[TextPart(text=content)],
index=0,
append=True,
lastChunk=False,
)
artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=artifact
)
yield json.dumps(artifact_event.model_dump())
# Check if final event
if data.get("done", False) and not final_sent:
logger.info(f"Received final event for task {task_id}")
# Create completed status event
completed_status = TaskStatus(
state="completed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[
TextPart(text=content_buffer or "Task completed")
],
),
)
# Final artifact with full content
final_artifact = Artifact(
name="response",
parts=[TextPart(text=content_buffer)],
index=0,
append=False,
lastChunk=True,
)
# Send the final artifact
final_artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=final_artifact
)
yield json.dumps(final_artifact_event.model_dump())
# Send the completed status
final_status_event = TaskStatusUpdateEvent(
id=task_id,
status=completed_status,
final=True,
)
yield json.dumps(final_status_event.model_dump())
final_sent = True
except json.JSONDecodeError as e:
logger.warning(
f"Received non-JSON event data: {e}. Data: {event_data[:100]}..."
)
# Handle non-JSON events - simply add to buffer as text
if isinstance(event_data, str):
content_buffer += event_data
# Create artifact update event
artifact = Artifact(
name="response",
parts=[TextPart(text=event_data)],
index=0,
append=True,
lastChunk=False,
)
artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=artifact
)
yield json.dumps(artifact_event.model_dump())
elif isinstance(event_data, dict):
# Try to extract text from the dictionary
text_value = str(event_data)
content_buffer += text_value
artifact = Artifact(
name="response",
parts=[TextPart(text=text_value)],
index=0,
append=True,
lastChunk=False,
)
artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=artifact
)
yield json.dumps(artifact_event.model_dump())
# Send heartbeat/keep-alive to keep the SSE connection open
now = datetime.now()
if (now - last_event_time).total_seconds() > heartbeat_interval:
logger.info(f"Sending heartbeat for task {task_id}")
# Sending keep-alive event as a "working" status event
working_heartbeat = TaskStatus(
state="working",
timestamp=now,
message=Message(
role="agent", parts=[TextPart(text="Still processing...")]
),
)
heartbeat_event = TaskStatusUpdateEvent(
id=task_id, status=working_heartbeat, final=False
)
yield json.dumps(heartbeat_event.model_dump())
last_event_time = now
# Ensure we send a final event if not already sent
if not final_sent:
logger.info(
f"Stream completed for task {task_id}, sending final status"
)
# Create completed status event
completed_status = TaskStatus(
state="completed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[TextPart(text=content_buffer or "Task completed")],
),
)
# Send the completed status
final_event = TaskStatusUpdateEvent(
id=task_id, status=completed_status, final=True
)
yield json.dumps(final_event.model_dump())
except Exception as e:
has_error = True
logger.error(f"Error in streaming for task {task_id}: {e}", exc_info=True)
# Create failed status event
failed_status = TaskStatus(
state="failed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[
TextPart(
text=f"Error during streaming: {str(e)}. Partial response: {content_buffer[:200] if content_buffer else 'No content received'}"
)
],
),
)
error_event = TaskStatusUpdateEvent(
id=task_id, status=failed_status, final=True
)
yield json.dumps(error_event.model_dump())
finally:
# Ensure we send a final event to properly close the connection
if not final_sent and not has_error:
logger.info(f"Stream finalizing for task {task_id} via finally block")
try:
# Create completed status event
completed_status = TaskStatus(
state="completed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[
TextPart(
text=content_buffer or "Task completed (forced end)"
)
],
),
)
# Send the completed status
final_event = TaskStatusUpdateEvent(
id=task_id, status=completed_status, final=True
)
yield json.dumps(final_event.model_dump())
except Exception as final_error:
logger.error(
f"Error sending final event in finally block: {final_error}"
)
logger.info(f"Streaming completed for task {task_id}")
def create_agent_card_from_agent(agent, db) -> AgentCard:
"""
Create an A2A agent card from an agent model.
Args:
agent: The agent model from the database
db: Database session
Returns:
A2A AgentCard object
"""
import os
from src.api.agent_routes import format_agent_tools
import asyncio
# Extract agent configuration
agent_config = agent.config
has_streaming = True # Assuming streaming is always supported
has_push = True # Assuming push notifications are supported
# Format tools as skills
try:
# We use a different approach to handle the asynchronous function
mcp_servers = agent_config.get("mcp_servers", [])
# We create a new thread to execute the asynchronous function
import concurrent.futures
def run_async(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coro)
loop.close()
return result
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_async, format_agent_tools(mcp_servers, db))
skills = future.result()
except Exception as e:
logger.error(f"Error formatting agent tools: {e}")
skills = []
# Create agent card
return AgentCard(
name=agent.name,
description=agent.description,
url=f"{os.getenv('API_URL', '')}/api/v1/a2a/{agent.id}",
provider=AgentProvider(
organization=os.getenv("ORGANIZATION_NAME", ""),
url=os.getenv("ORGANIZATION_URL", ""),
),
version=os.getenv("API_VERSION", "1.0.0"),
capabilities=AgentCapabilities(
streaming=has_streaming,
pushNotifications=has_push,
stateTransitionHistory=True,
),
authentication={
"schemes": ["apiKey"],
"credentials": "x-api-key",
},
defaultInputModes=["text", "application/json"],
defaultOutputModes=["text", "application/json"],
skills=skills,
)

View File

@ -1,708 +0,0 @@
"""
Server A2A and task manager for the A2A protocol.
This module implements a JSON-RPC compatible server for the A2A protocol,
that manages agent tasks, streaming events and push notifications.
"""
import json
import logging
from datetime import datetime
from typing import (
Any,
Dict,
List,
Optional,
AsyncGenerator,
Union,
AsyncIterable,
)
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse, Response
from sqlalchemy.orm import Session
from src.schemas.a2a.types import A2ARequest
from src.services.a2a_integration_service import (
AgentRunnerAdapter,
StreamingServiceAdapter,
)
from src.services.session_service import get_session_events
from src.services.redis_cache_service import RedisCacheService
from src.schemas.a2a.types import (
SendTaskRequest,
SendTaskStreamingRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
)
logger = logging.getLogger(__name__)
class A2ATaskManager:
"""
Task manager for the A2A protocol.
This class manages the lifecycle of tasks, including:
- Task execution
- Streaming of events
- Push notifications
- Status querying
- Cancellation
"""
def __init__(
self,
redis_cache: RedisCacheService,
agent_runner: AgentRunnerAdapter,
streaming_service: StreamingServiceAdapter,
push_notification_service: Any = None,
):
"""
Initialize the task manager.
Args:
redis_cache: Cache service for storing task data
agent_runner: Adapter for agent execution
streaming_service: Adapter for event streaming
push_notification_service: Service for sending push notifications
"""
self.cache = redis_cache
self.agent_runner = agent_runner
self.streaming_service = streaming_service
self.push_notification_service = push_notification_service
self._running_tasks = {}
async def on_send_task(
self,
task_id: str,
agent_id: str,
message: Dict[str, Any],
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
input_mode: str = "text",
output_modes: List[str] = ["text"],
db: Optional[Session] = None,
) -> Dict[str, Any]:
"""
Process a request to send a task.
Args:
task_id: Task ID
agent_id: Agent ID
message: User message
session_id: Session ID (optional)
metadata: Additional metadata (optional)
input_mode: Input mode (text, JSON, etc.)
output_modes: Supported output modes
db: Database session
Returns:
Response with task result
"""
if not session_id:
session_id = f"{task_id}_{agent_id}"
if not metadata:
metadata = {}
# Update status to "submitted"
task_data = {
"id": task_id,
"sessionId": session_id,
"status": {
"state": "submitted",
"timestamp": datetime.now().isoformat(),
"message": None,
"error": None,
},
"artifacts": [],
"history": [],
"metadata": metadata,
}
# Store initial task data
await self.cache.set(f"task:{task_id}", task_data)
# Check for push notification configurations
push_config = await self.cache.get(f"task:{task_id}:push")
if push_config and self.push_notification_service:
# Send initial notification
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="submitted",
headers=push_config.get("headers", {}),
)
try:
# Update status to "running"
task_data["status"].update(
{"state": "running", "timestamp": datetime.now().isoformat()}
)
await self.cache.set(f"task:{task_id}", task_data)
# Notify "running" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="running",
headers=push_config.get("headers", {}),
)
# Extract user message
user_message = None
try:
user_message = message["parts"][0]["text"]
except (KeyError, IndexError):
user_message = ""
# Execute the agent
response = await self.agent_runner.run_agent(
agent_id=agent_id,
task_id=task_id,
message=user_message,
session_id=session_id,
db=db,
)
# Check if the response is a dictionary (error) or a string (success)
if isinstance(response, dict) and response.get("status") == "error":
# Error response
final_response = f"Error: {response.get('error', 'Unknown error')}"
# Update status to "failed"
task_data["status"].update(
{
"state": "failed",
"timestamp": datetime.now().isoformat(),
"error": {
"code": "AGENT_EXECUTION_ERROR",
"message": response.get("error", "Unknown error"),
},
"message": {
"role": "system",
"parts": [{"type": "text", "text": final_response}],
},
}
)
# Notify "failed" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="failed",
message={
"role": "system",
"parts": [{"type": "text", "text": final_response}],
},
headers=push_config.get("headers", {}),
)
else:
# Success response
final_response = (
response.get("content") if isinstance(response, dict) else response
)
# Update status to "completed"
task_data["status"].update(
{
"state": "completed",
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": final_response}],
},
}
)
# Add artifacts
if final_response:
task_data["artifacts"].append(
{
"type": "text",
"content": final_response,
"metadata": {
"generated_at": datetime.now().isoformat(),
"content_type": "text/plain",
},
}
)
# Add history of messages
history_length = metadata.get("historyLength", 50)
try:
history_messages = get_session_events(
self.agent_runner.session_service, session_id
)
history_messages = history_messages[-history_length:]
formatted_history = []
for event in history_messages:
if event.content and event.content.parts:
role = (
"agent"
if event.content.role == "model"
else event.content.role
)
formatted_history.append(
{
"role": role,
"parts": [
{"type": "text", "text": part.text}
for part in event.content.parts
if part.text
],
}
)
task_data["history"] = formatted_history
except Exception as e:
logger.error(f"Error processing history: {str(e)}")
# Notify "completed" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="completed",
message={
"role": "agent",
"parts": [{"type": "text", "text": final_response}],
},
headers=push_config.get("headers", {}),
)
except Exception as e:
logger.error(f"Error executing task {task_id}: {str(e)}")
# Update status to "failed"
task_data["status"].update(
{
"state": "failed",
"timestamp": datetime.now().isoformat(),
"error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)},
}
)
# Notify "failed" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="failed",
message={
"role": "system",
"parts": [{"type": "text", "text": str(e)}],
},
headers=push_config.get("headers", {}),
)
# Store final result
await self.cache.set(f"task:{task_id}", task_data)
return task_data
async def on_send_task_subscribe(
self,
task_id: str,
agent_id: str,
message: Dict[str, Any],
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
input_mode: str = "text",
output_modes: List[str] = ["text"],
db: Optional[Session] = None,
) -> AsyncGenerator[str, None]:
"""
Process a request to send a task with streaming.
Args:
task_id: Task ID
agent_id: Agent ID
message: User message
session_id: Session ID (optional)
metadata: Additional metadata (optional)
input_mode: Input mode (text, JSON, etc.)
output_modes: Supported output modes
db: Database session
Yields:
Streaming events in SSE (Server-Sent Events) format
"""
if not session_id:
session_id = f"{task_id}_{agent_id}"
if not metadata:
metadata = {}
# Extract user message
user_message = ""
try:
user_message = message["parts"][0]["text"]
except (KeyError, IndexError):
pass
# Generate streaming events
async for event in self.streaming_service.stream_response(
agent_id=agent_id,
task_id=task_id,
message=user_message,
session_id=session_id,
db=db,
):
yield event
async def on_get_task(self, task_id: str) -> Dict[str, Any]:
"""
Query the status of a task by ID.
Args:
task_id: Task ID
Returns:
Current task status
Raises:
Exception: If the task is not found
"""
task_data = await self.cache.get(f"task:{task_id}")
if not task_data:
raise Exception(f"Task {task_id} not found")
return task_data
async def on_cancel_task(self, task_id: str) -> Dict[str, Any]:
"""
Cancel a running task.
Args:
task_id: Task ID to be cancelled
Returns:
Task status after cancellation
Raises:
Exception: If the task is not found or cannot be cancelled
"""
task_data = await self.cache.get(f"task:{task_id}")
if not task_data:
raise Exception(f"Task {task_id} not found")
# Check if the task is in a state that can be cancelled
current_state = task_data["status"]["state"]
if current_state not in ["submitted", "running"]:
raise Exception(f"Cannot cancel task in {current_state} state")
# Cancel the task in the runner if it is running
running_task = self._running_tasks.get(task_id)
if running_task:
# Try to cancel the running task
if hasattr(running_task, "cancel"):
running_task.cancel()
# Update status to "cancelled"
task_data["status"].update(
{
"state": "cancelled",
"timestamp": datetime.now().isoformat(),
"message": {
"role": "system",
"parts": [{"type": "text", "text": "Task cancelled by user"}],
},
}
)
# Update cache
await self.cache.set(f"task:{task_id}", task_data)
# Send push notification if configured
push_config = await self.cache.get(f"task:{task_id}:push")
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="cancelled",
message={
"role": "system",
"parts": [{"type": "text", "text": "Task cancelled by user"}],
},
headers=push_config.get("headers", {}),
)
return task_data
async def on_set_task_push_notification(
self, task_id: str, notification_config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Configure push notifications for a task.
Args:
task_id: Task ID
notification_config: Notification configuration (URL and headers)
Returns:
Updated configuration
"""
# Validate configuration
url = notification_config.get("url")
if not url:
raise ValueError("Push notification URL is required")
headers = notification_config.get("headers", {})
# Store configuration
config = {"url": url, "headers": headers}
await self.cache.set(f"task:{task_id}:push", config)
return config
async def on_get_task_push_notification(self, task_id: str) -> Dict[str, Any]:
"""
Get the push notification configuration for a task.
Args:
task_id: Task ID
Returns:
Push notification configuration
Raises:
Exception: If there is no configuration for the task
"""
config = await self.cache.get(f"task:{task_id}:push")
if not config:
raise Exception(f"No push notification configuration for task {task_id}")
return config
class A2AServer:
"""
A2A server compatible with JSON-RPC 2.0.
This class processes JSON-RPC requests and forwards them to
the appropriate handlers in the A2ATaskManager.
"""
def __init__(self, task_manager: A2ATaskManager, agent_card=None):
"""
Initialize the A2A server.
Args:
task_manager: Task manager
agent_card: Agent card information
"""
self.task_manager = task_manager
self.agent_card = agent_card
async def process_request(
self,
request: Request,
agent_id: Optional[str] = None,
db: Optional[Session] = None,
) -> Union[Response, JSONResponse, StreamingResponse]:
"""
Process a JSON-RPC request.
Args:
request: HTTP request
agent_id: Optional agent ID to inject into the request
db: Database session
Returns:
Appropriate response (JSON or Streaming)
"""
try:
# Try to parse the JSON payload
try:
logger.info("Starting JSON-RPC request processing")
body = await request.json()
logger.info(f"Received JSON data: {json.dumps(body)}")
method = body.get("method", "unknown")
# Validate the request using the A2A validator
json_rpc_request = A2ARequest.validate_python(body)
original_db = self.task_manager.db
try:
# Set the db temporarily
if db is not None:
self.task_manager.db = db
# Process the request
if isinstance(json_rpc_request, SendTaskRequest):
json_rpc_request.params.agentId = agent_id
result = await self.task_manager.on_send_task(json_rpc_request)
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
json_rpc_request.params.agentId = agent_id
result = await self.task_manager.on_send_task_subscribe(
json_rpc_request
)
elif isinstance(json_rpc_request, GetTaskRequest):
result = await self.task_manager.on_get_task(json_rpc_request)
elif isinstance(json_rpc_request, CancelTaskRequest):
result = await self.task_manager.on_cancel_task(
json_rpc_request
)
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
result = await self.task_manager.on_set_task_push_notification(
json_rpc_request
)
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
result = await self.task_manager.on_get_task_push_notification(
json_rpc_request
)
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
result = await self.task_manager.on_resubscribe_to_task(
json_rpc_request
)
else:
logger.warning(
f"[SERVER] Request type not supported: {type(json_rpc_request)}"
)
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id"),
"error": {
"code": -32601,
"message": "Method not found",
"data": {
"detail": f"Method not supported: {method}"
},
},
},
)
finally:
# Restore the original db
self.task_manager.db = original_db
# Create appropriate response
return self._create_response(result)
except json.JSONDecodeError as e:
# Error parsing JSON
logger.error(f"Error parsing JSON request: {str(e)}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32700,
"message": "Parse error",
"data": {"detail": str(e)},
},
},
)
except Exception as e:
# Other validation errors
logger.error(f"Error validating request: {str(e)}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id") if "body" in locals() else None,
"error": {
"code": -32600,
"message": "Invalid Request",
"data": {"detail": str(e)},
},
},
)
except Exception as e:
logger.error(f"Error processing JSON-RPC request: {str(e)}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32603,
"message": "Internal error",
"data": {"detail": str(e)},
},
},
)
def _create_response(self, result: Any) -> Union[JSONResponse, StreamingResponse]:
"""
Create appropriate response based on result type.
Args:
result: Result from task manager
Returns:
JSON or streaming response
"""
if isinstance(result, AsyncIterable):
# Result in streaming (SSE)
async def event_generator():
async for item in result:
if hasattr(item, "model_dump_json"):
yield {"data": item.model_dump_json(exclude_none=True)}
else:
yield {"data": json.dumps(item)}
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
elif hasattr(result, "model_dump"):
# Result is a Pydantic object
return JSONResponse(result.model_dump(exclude_none=True))
else:
# Result is a dictionary or other simple type
return JSONResponse(result)
async def get_agent_card(
self, request: Request, db: Optional[Session] = None
) -> JSONResponse:
"""
Get the agent card.
Args:
request: HTTP request
db: Database session
Returns:
Agent card as JSON
"""
if not self.agent_card:
logger.error("Agent card not configured")
return JSONResponse(
status_code=404, content={"error": "Agent card not configured"}
)
# If there is db, set it temporarily in the task_manager
if db and hasattr(self.task_manager, "db"):
original_db = self.task_manager.db
try:
self.task_manager.db = db
# If it's a Pydantic object, convert to dictionary
if hasattr(self.agent_card, "model_dump"):
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
else:
return JSONResponse(self.agent_card)
finally:
# Restore the original db
self.task_manager.db = original_db
else:
# If it's a Pydantic object, convert to dictionary
if hasattr(self.agent_card, "model_dump"):
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
else:
return JSONResponse(self.agent_card)

View File

@ -0,0 +1,644 @@
import json
import logging
import asyncio
from collections.abc import AsyncIterable
from typing import Any, Dict, Optional, Union, List
from uuid import UUID
from sqlalchemy.orm import Session
from src.config.settings import settings
from src.services.agent_service import (
get_agent,
create_agent,
update_agent,
delete_agent,
get_agents_by_client,
)
from src.services.mcp_server_service import get_mcp_server
from src.services.session_service import (
get_sessions_by_client,
get_sessions_by_agent,
get_session_by_id,
delete_session,
get_session_events,
)
from src.services.agent_runner import run_agent, run_agent_stream
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from src.models.models import Agent
from src.schemas.a2a_types import (
A2ARequest,
GetTaskRequest,
SendTaskRequest,
SendTaskResponse,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
JSONRPCResponse,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
Task,
TaskSendParams,
InternalError,
Message,
Artifact,
TaskStatus,
TaskState,
AgentCard,
AgentCapabilities,
AgentSkill,
)
logger = logging.getLogger(__name__)
class A2ATaskManager:
"""Task manager for the A2A protocol."""
def __init__(self, db: Session):
self.db = db
self.tasks: Dict[str, Task] = {}
self.lock = asyncio.Lock()
async def upsert_task(self, task_params: TaskSendParams) -> Task:
"""Creates or updates a task in the store."""
async with self.lock:
task = self.tasks.get(task_params.id)
if task is None:
# Create new task with initial history
task = Task(
id=task_params.id,
sessionId=task_params.sessionId,
status=TaskStatus(state=TaskState.SUBMITTED),
history=[task_params.message],
artifacts=[],
)
self.tasks[task_params.id] = task
else:
# Add message to existing history
if task.history is None:
task.history = []
task.history.append(task_params.message)
return task
async def on_get_task(self, request: GetTaskRequest) -> JSONRPCResponse:
"""Handles requests to get task details."""
try:
task_id = request.params.id
history_length = request.params.historyLength
async with self.lock:
if task_id not in self.tasks:
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Task {task_id} not found"),
)
# Get the task and limit the history as requested
task = self.tasks[task_id]
task_result = self.append_task_history(task, history_length)
return SendTaskResponse(id=request.id, result=task_result)
except Exception as e:
logger.error(f"Error getting task: {e}")
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error getting task: {str(e)}"),
)
async def on_send_task(
self, request: SendTaskRequest, agent_id: UUID
) -> JSONRPCResponse:
"""Handles requests to send a task for processing."""
try:
agent = get_agent(self.db, agent_id)
if not agent:
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Agent {agent_id} not found"),
)
await self.upsert_task(request.params)
return await self._process_task(request, agent)
except Exception as e:
logger.error(f"Error sending task: {e}")
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error sending task: {str(e)}"),
)
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest, agent_id: UUID
) -> AsyncIterable[SendTaskStreamingResponse]:
"""Handles requests to send a task and subscribe to updates."""
try:
agent = get_agent(self.db, agent_id)
if not agent:
yield JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Agent {agent_id} not found"),
)
return
await self.upsert_task(request.params)
async for response in self._stream_task_process(request, agent):
yield response
except Exception as e:
logger.error(f"Error processing streaming task: {e}")
yield JSONRPCResponse(
id=request.id,
error=InternalError(
message=f"Error processing streaming task: {str(e)}"
),
)
async def on_cancel_task(self, request: CancelTaskRequest) -> JSONRPCResponse:
"""Handles requests to cancel a task."""
try:
task_id = request.params.id
async with self.lock:
if task_id in self.tasks:
task = self.tasks[task_id]
task.status = TaskStatus(state=TaskState.CANCELED)
return JSONRPCResponse(id=request.id, result=True)
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Task {task_id} not found"),
)
except Exception as e:
logger.error(f"Error canceling task: {e}")
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error canceling task: {str(e)}"),
)
async def on_set_task_push_notification(
self, request: SetTaskPushNotificationRequest
) -> JSONRPCResponse:
"""Handles requests to configure push notifications for a task."""
return JSONRPCResponse(id=request.id, result=True)
async def on_get_task_push_notification(
self, request: GetTaskPushNotificationRequest
) -> JSONRPCResponse:
"""Handles requests to get push notification settings for a task."""
return JSONRPCResponse(id=request.id, result={})
async def on_resubscribe_to_task(
self, request: TaskResubscriptionRequest
) -> AsyncIterable[SendTaskStreamingResponse]:
"""Handles requests to resubscribe to a task."""
task_id = request.params.id
try:
async with self.lock:
if task_id not in self.tasks:
yield SendTaskStreamingResponse(
id=request.id,
error=InternalError(message=f"Task {task_id} not found"),
)
return
task = self.tasks[task_id]
yield SendTaskStreamingResponse(
id=request.id,
result=TaskStatusUpdateEvent(
id=task_id,
status=task.status,
final=False,
),
)
if task.artifacts:
for artifact in task.artifacts:
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(id=task_id, artifact=artifact),
)
yield SendTaskStreamingResponse(
id=request.id,
result=TaskStatusUpdateEvent(
id=task_id,
status=TaskStatus(state=task.status.state),
final=True,
),
)
except Exception as e:
logger.error(f"Error resubscribing to task: {e}")
yield SendTaskStreamingResponse(
id=request.id,
error=InternalError(message=f"Error resubscribing to task: {str(e)}"),
)
async def _process_task(
self, request: SendTaskRequest, agent: Agent
) -> JSONRPCResponse:
"""Processes a task using the specified agent."""
task_params = request.params
query = self._extract_user_query(task_params)
try:
# Process the query with the agent
result = await self._run_agent(agent, query, task_params.sessionId)
# Create the response part
text_part = {"type": "text", "text": result}
parts = [text_part]
agent_message = Message(role="agent", parts=parts)
# Determine the task state
task_state = (
TaskState.INPUT_REQUIRED
if "MISSING_INFO:" in result
else TaskState.COMPLETED
)
# Update the task in the store
task = await self.update_store(
task_params.id,
TaskStatus(state=task_state, message=agent_message),
[Artifact(parts=parts, index=0)],
)
return SendTaskResponse(id=request.id, result=task)
except Exception as e:
logger.error(f"Error processing task: {e}")
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error processing task: {str(e)}"),
)
async def _stream_task_process(
self, request: SendTaskStreamingRequest, agent: Agent
) -> AsyncIterable[SendTaskStreamingResponse]:
"""Processes a task in streaming mode using the specified agent."""
task_params = request.params
query = self._extract_user_query(task_params)
try:
# Send initial processing status
processing_text_part = {
"type": "text",
"text": "Processing your request...",
}
processing_message = Message(role="agent", parts=[processing_text_part])
# Update the task with the processing message and inform the WORKING state
await self.update_store(
task_params.id,
TaskStatus(state=TaskState.WORKING, message=processing_message),
)
yield SendTaskStreamingResponse(
id=request.id,
result=TaskStatusUpdateEvent(
id=task_params.id,
status=TaskStatus(
state=TaskState.WORKING,
message=processing_message,
),
final=False,
),
)
# Collect the chunks of the agent's response
contact_id = task_params.sessionId
full_response = ""
# We use the same streaming function used in the WebSocket
async for chunk in run_agent_stream(
agent_id=str(agent.id),
contact_id=contact_id,
message=query,
session_service=session_service,
artifacts_service=artifacts_service,
memory_service=memory_service,
db=self.db,
):
# Send incremental progress updates
update_text_part = {"type": "text", "text": chunk}
update_message = Message(role="agent", parts=[update_text_part])
# Update the task with each intermediate message
await self.update_store(
task_params.id,
TaskStatus(state=TaskState.WORKING, message=update_message),
)
yield SendTaskStreamingResponse(
id=request.id,
result=TaskStatusUpdateEvent(
id=task_params.id,
status=TaskStatus(
state=TaskState.WORKING,
message=update_message,
),
final=False,
),
)
full_response += chunk
# Determine the task state
task_state = (
TaskState.INPUT_REQUIRED
if "MISSING_INFO:" in full_response
else TaskState.COMPLETED
)
# Create the final response part
final_text_part = {"type": "text", "text": full_response}
parts = [final_text_part]
final_message = Message(role="agent", parts=parts)
# Create the final artifact from the final response
final_artifact = Artifact(parts=parts, index=0)
# Update the task in the store with the final response
task = await self.update_store(
task_params.id,
TaskStatus(state=task_state, message=final_message),
[final_artifact],
)
# Send the final artifact
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=task_params.id, artifact=final_artifact
),
)
# Send the final status
yield SendTaskStreamingResponse(
id=request.id,
result=TaskStatusUpdateEvent(
id=task_params.id,
status=TaskStatus(state=task_state),
final=True,
),
)
except Exception as e:
logger.error(f"Error streaming task process: {e}")
yield JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error streaming task process: {str(e)}"),
)
async def update_store(
self,
task_id: str,
status: TaskStatus,
artifacts: Optional[list[Artifact]] = None,
) -> Task:
"""Updates the status and artifacts of a task."""
async with self.lock:
if task_id not in self.tasks:
raise ValueError(f"Task {task_id} not found")
task = self.tasks[task_id]
task.status = status
# Add message to history if it exists
if status.message is not None:
if task.history is None:
task.history = []
task.history.append(status.message)
if artifacts is not None:
if task.artifacts is None:
task.artifacts = []
task.artifacts.extend(artifacts)
return task
def _extract_user_query(self, task_params: TaskSendParams) -> str:
"""Extracts the user query from the task parameters."""
if not task_params.message or not task_params.message.parts:
raise ValueError("Message or parts are missing in task parameters")
part = task_params.message.parts[0]
if part.type != "text":
raise ValueError("Only text parts are supported")
return part.text
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> str:
"""Executes the agent to process the user query."""
try:
# We use the session_id as contact_id to maintain the conversation continuity
contact_id = session_id
# We call the same function used in the chat API
final_response = await run_agent(
agent_id=str(agent.id),
contact_id=contact_id,
message=query,
session_service=session_service,
artifacts_service=artifacts_service,
memory_service=memory_service,
db=self.db,
)
return final_response
except Exception as e:
logger.error(f"Error running agent: {e}")
raise ValueError(f"Error running agent: {str(e)}")
def append_task_history(self, task: Task, history_length: int | None) -> Task:
"""Returns a copy of the task with the history limited to the specified size."""
# Create a copy of the task
new_task = task.model_copy()
# Limit the history if requested
if history_length is not None:
if history_length > 0:
new_task.history = (
new_task.history[-history_length:] if new_task.history else []
)
else:
new_task.history = []
return new_task
class A2AService:
"""Service to manage A2A requests and agent cards."""
def __init__(self, db: Session, task_manager: A2ATaskManager):
self.db = db
self.task_manager = task_manager
async def process_request(
self, agent_id: UUID, request_body: dict
) -> JSONRPCResponse:
"""Processes an A2A request."""
try:
request = A2ARequest.validate_python(request_body)
if isinstance(request, GetTaskRequest):
return await self.task_manager.on_get_task(request)
elif isinstance(request, SendTaskRequest):
return await self.task_manager.on_send_task(request, agent_id)
elif isinstance(request, SendTaskStreamingRequest):
return self.task_manager.on_send_task_subscribe(request, agent_id)
elif isinstance(request, CancelTaskRequest):
return await self.task_manager.on_cancel_task(request)
elif isinstance(request, SetTaskPushNotificationRequest):
return await self.task_manager.on_set_task_push_notification(request)
elif isinstance(request, GetTaskPushNotificationRequest):
return await self.task_manager.on_get_task_push_notification(request)
elif isinstance(request, TaskResubscriptionRequest):
return self.task_manager.on_resubscribe_to_task(request)
else:
logger.warning(f"Unexpected request type: {type(request)}")
return JSONRPCResponse(
id=getattr(request, "id", None),
error=InternalError(
message=f"Unexpected request type: {type(request)}"
),
)
except Exception as e:
logger.error(f"Error processing A2A request: {e}")
return JSONRPCResponse(
id=None,
error=InternalError(message=f"Error processing A2A request: {str(e)}"),
)
def get_agent_card(self, agent_id: UUID) -> AgentCard:
"""Gets the agent card for the specified agent."""
agent = get_agent(self.db, agent_id)
if not agent:
raise ValueError(f"Agent {agent_id} not found")
# Build the agent card based on the agent's information
capabilities = AgentCapabilities(streaming=True)
# List to store all skills
skills = []
# Check if the agent has MCP servers configured
if (
agent.config
and "mcp_servers" in agent.config
and agent.config["mcp_servers"]
):
logger.info(
f"Agent {agent_id} has {len(agent.config['mcp_servers'])} MCP servers configured"
)
for mcp_config in agent.config["mcp_servers"]:
# Get the MCP server
mcp_server_id = mcp_config.get("id")
if not mcp_server_id:
logger.warning("MCP server configuration missing ID")
continue
logger.info(f"Processing MCP server: {mcp_server_id}")
mcp_server = get_mcp_server(self.db, mcp_server_id)
if not mcp_server:
logger.warning(f"MCP server {mcp_server_id} not found")
continue
# Get the available tools in the MCP server
mcp_tools = mcp_config.get("tools", [])
logger.info(f"MCP server {mcp_server.name} has tools: {mcp_tools}")
# Add server tools as skills
for tool_name in mcp_tools:
logger.info(f"Processing tool: {tool_name}")
# Buscar informações da ferramenta pelo ID
tool_info = None
if hasattr(mcp_server, "tools") and isinstance(
mcp_server.tools, list
):
for tool in mcp_server.tools:
if isinstance(tool, dict) and tool.get("id") == tool_name:
tool_info = tool
logger.info(
f"Found tool info for {tool_name}: {tool_info}"
)
break
if tool_info:
# Use the information from the tool
skill = AgentSkill(
id=tool_info.get("id", f"{agent.id}_{tool_name}"),
name=tool_info.get("name", tool_name),
description=tool_info.get(
"description", f"Tool: {tool_name}"
),
tags=tool_info.get(
"tags", [mcp_server.name, "tool", tool_name]
),
examples=tool_info.get("examples", []),
inputModes=tool_info.get("inputModes", ["text"]),
outputModes=tool_info.get("outputModes", ["text"]),
)
else:
# Default skill if tool info not found
skill = AgentSkill(
id=f"{agent.id}_{tool_name}",
name=tool_name,
description=f"Tool: {tool_name}",
tags=[mcp_server.name, "tool", tool_name],
examples=[],
inputModes=["text"],
outputModes=["text"],
)
skills.append(skill)
logger.info(f"Added skill for tool: {tool_name}")
# Check custom tools
if (
agent.config
and "custom_tools" in agent.config
and agent.config["custom_tools"]
):
custom_tools = agent.config["custom_tools"]
# Check HTTP tools
if "http_tools" in custom_tools and custom_tools["http_tools"]:
logger.info(f"Agent has {len(custom_tools['http_tools'])} HTTP tools")
for http_tool in custom_tools["http_tools"]:
skill = AgentSkill(
id=f"{agent.id}_http_{http_tool['name']}",
name=http_tool["name"],
description=http_tool.get(
"description", f"HTTP Tool: {http_tool['name']}"
),
tags=http_tool.get(
"tags", ["http", "custom_tool", http_tool["method"]]
),
examples=http_tool.get("examples", []),
inputModes=http_tool.get("inputModes", ["text"]),
outputModes=http_tool.get("outputModes", ["text"]),
)
skills.append(skill)
logger.info(f"Added skill for HTTP tool: {http_tool['name']}")
card = AgentCard(
name=agent.name,
description=agent.description or "",
url=f"{settings.API_URL}/api/v1/a2a/{agent_id}",
version="1.0.0",
defaultInputModes=["text"],
defaultOutputModes=["text"],
capabilities=capabilities,
skills=skills,
)
logger.info(f"Generated agent card with {len(skills)} skills")
return card

View File

@ -1,949 +0,0 @@
"""
A2A Task Manager Service.
This service implements task management for the A2A protocol, handling task lifecycle
including execution, streaming, push notifications, status queries, and cancellation.
"""
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, Union, AsyncIterable
import uuid
from src.schemas.a2a.exceptions import (
TaskNotFoundError,
TaskNotCancelableError,
PushNotificationNotSupportedError,
InternalError,
ContentTypeNotSupportedError,
)
from src.schemas.a2a.types import (
JSONRPCResponse,
GetTaskRequest,
SendTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
GetTaskResponse,
CancelTaskResponse,
SendTaskResponse,
SetTaskPushNotificationResponse,
GetTaskPushNotificationResponse,
TaskSendParams,
TaskStatus,
TaskState,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
Artifact,
PushNotificationConfig,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
JSONRPCError,
TaskPushNotificationConfig,
Message,
TextPart,
Task,
)
from src.services.redis_cache_service import RedisCacheService
from src.utils.a2a_utils import (
are_modalities_compatible,
new_incompatible_types_error,
)
logger = logging.getLogger(__name__)
class A2ATaskManager:
"""
A2A Task Manager implementation.
This class manages the lifecycle of A2A tasks, including:
- Task submission and execution
- Task status queries
- Task cancellation
- Push notification configuration
- SSE streaming for real-time updates
"""
def __init__(
self,
redis_cache: RedisCacheService,
agent_runner=None,
streaming_service=None,
push_notification_service=None,
db=None,
):
"""
Initialize the A2A Task Manager.
Args:
redis_cache: Redis cache service for task storage
agent_runner: Agent runner service for task execution
streaming_service: Streaming service for SSE
push_notification_service: Service for push notifications
db: Database session
"""
self.redis_cache = redis_cache
self.agent_runner = agent_runner
self.streaming_service = streaming_service
self.push_notification_service = push_notification_service
self.db = db
self.lock = asyncio.Lock()
self.subscriber_lock = asyncio.Lock()
self.task_sse_subscribers = {}
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
"""
Handle request to get task information.
Args:
request: A2A Get Task request
Returns:
Response with task details
"""
try:
task_id = request.params.id
history_length = request.params.historyLength
# Get task data from cache
task_data = await self.redis_cache.get(f"task:{task_id}")
if not task_data:
logger.warning(f"Task not found: {task_id}")
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
# Create a Task instance from cache data
task = Task.model_validate(task_data)
# If historyLength parameter is present, handle the history
if history_length is not None and task.history:
if history_length == 0:
task.history = []
elif len(task.history) > history_length:
task.history = task.history[-history_length:]
return GetTaskResponse(id=request.id, result=task)
except Exception as e:
logger.error(f"Error processing on_get_task: {str(e)}")
return GetTaskResponse(id=request.id, error=InternalError(message=str(e)))
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
"""
Handle request to cancel a running task.
Args:
request: The JSON-RPC request to cancel a task
Returns:
Response with updated task data or error
"""
logger.info(f"Cancelling task {request.params.id}")
task_id_params = request.params
try:
task_data = await self.redis_cache.get(f"task:{task_id_params.id}")
if not task_data:
logger.warning(f"Task {task_id_params.id} not found for cancellation")
return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
# Check if task can be cancelled
current_state = task_data.get("status", {}).get("state")
if current_state not in [TaskState.SUBMITTED, TaskState.WORKING]:
logger.warning(
f"Task {task_id_params.id} in state {current_state} cannot be cancelled"
)
return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
# Update task status to cancelled
task_data["status"] = {
"state": TaskState.CANCELED,
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": "Task cancelled by user"}],
},
}
# Save updated task data
await self.redis_cache.set(f"task:{task_id_params.id}", task_data)
# Send push notification if configured
await self._send_push_notification_for_task(
task_id_params.id, "canceled", system_message="Task cancelled by user"
)
# Publish event to SSE subscribers
await self._publish_task_update(
task_id_params.id,
TaskStatusUpdateEvent(
id=task_id_params.id,
status=TaskStatus(
state=TaskState.CANCELED,
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[TextPart(text="Task cancelled by user")],
),
),
final=True,
),
)
return CancelTaskResponse(id=request.id, result=task_data)
except Exception as e:
logger.error(f"Error cancelling task: {str(e)}", exc_info=True)
return CancelTaskResponse(
id=request.id,
error=InternalError(message=f"Error cancelling task: {str(e)}"),
)
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
"""
Handle request to send a new task.
Args:
request: Send Task request
Returns:
Response with the created task details
"""
try:
params = request.params
task_id = params.id
logger.info(f"Receiving task {task_id}")
# Check if a task with this ID already exists
existing_task = await self.redis_cache.get(f"task:{task_id}")
if existing_task:
# If the task already exists and is in progress, return the current task
if existing_task.get("status", {}).get("state") in [
TaskState.WORKING,
TaskState.COMPLETED,
]:
return SendTaskResponse(
id=request.id, result=Task.model_validate(existing_task)
)
# If the task exists but failed or was canceled, we can reprocess it
logger.info(f"Reprocessing existing task {task_id}")
# Check modality compatibility
server_output_modes = []
if self.agent_runner:
# Try to get supported modes from the agent
try:
server_output_modes = await self.agent_runner.get_supported_modes()
except Exception as e:
logger.warning(f"Error getting supported modes: {str(e)}")
server_output_modes = ["text"] # Fallback to text
if not are_modalities_compatible(
server_output_modes, params.acceptedOutputModes
):
logger.warning(
f"Incompatible modes: server={server_output_modes}, client={params.acceptedOutputModes}"
)
return SendTaskResponse(
id=request.id, error=ContentTypeNotSupportedError()
)
# Create task data
task_data = await self._create_task_data(params)
# Store task in cache
await self.redis_cache.set(f"task:{task_id}", task_data)
# Configure push notifications, if provided
if params.pushNotification:
await self.redis_cache.set(
f"task_notification:{task_id}", params.pushNotification.model_dump()
)
# Execute task SYNCHRONOUSLY instead of in background
# This is the key change for A2A compatibility
task_data = await self._execute_task(task_data, params)
# Convert to Task object and return
task = Task.model_validate(task_data)
return SendTaskResponse(id=request.id, result=task)
except Exception as e:
logger.error(f"Error processing on_send_task: {str(e)}")
return SendTaskResponse(id=request.id, error=InternalError(message=str(e)))
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
"""
Handle request to send a task and subscribe to streaming updates.
Args:
request: The JSON-RPC request to send a task with streaming
Returns:
Stream of events or error response
"""
logger.info(f"Sending task with streaming {request.params.id}")
task_send_params = request.params
try:
# Check output mode compatibility
if not are_modalities_compatible(
["text", "application/json"], # Default supported modes
task_send_params.acceptedOutputModes,
):
return new_incompatible_types_error(request.id)
# Create initial task data
task_data = await self._create_task_data(task_send_params)
# Setup SSE consumer
sse_queue = await self._setup_sse_consumer(task_send_params.id)
# Execute task asynchronously (fire and forget)
asyncio.create_task(self._execute_task(task_data, task_send_params))
# Return generator to dequeue events for SSE
return self._dequeue_events_for_sse(
request.id, task_send_params.id, sse_queue
)
except Exception as e:
logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True)
return SendTaskStreamingResponse(
id=request.id,
error=InternalError(
message=f"Error setting up streaming task: {str(e)}"
),
)
async def on_set_task_push_notification(
self, request: SetTaskPushNotificationRequest
) -> SetTaskPushNotificationResponse:
"""
Configure push notifications for a task.
Args:
request: The JSON-RPC request to set push notification
Returns:
Response with configuration or error
"""
logger.info(f"Setting push notification for task {request.params.id}")
task_notification_params = request.params
try:
if not self.push_notification_service:
logger.warning("Push notifications not supported")
return SetTaskPushNotificationResponse(
id=request.id, error=PushNotificationNotSupportedError()
)
# Check if task exists
task_data = await self.redis_cache.get(
f"task:{task_notification_params.id}"
)
if not task_data:
logger.warning(
f"Task {task_notification_params.id} not found for setting push notification"
)
return SetTaskPushNotificationResponse(
id=request.id, error=TaskNotFoundError()
)
# Save push notification config
config = {
"url": task_notification_params.pushNotificationConfig.url,
"headers": {}, # Add auth headers if needed
}
await self.redis_cache.set(
f"task:{task_notification_params.id}:push", config
)
return SetTaskPushNotificationResponse(
id=request.id, result=task_notification_params
)
except Exception as e:
logger.error(f"Error setting push notification: {str(e)}", exc_info=True)
return SetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message=f"Error setting push notification: {str(e)}"
),
)
async def on_get_task_push_notification(
self, request: GetTaskPushNotificationRequest
) -> GetTaskPushNotificationResponse:
"""
Get push notification configuration for a task.
Args:
request: The JSON-RPC request to get push notification config
Returns:
Response with configuration or error
"""
logger.info(f"Getting push notification for task {request.params.id}")
task_params = request.params
try:
if not self.push_notification_service:
logger.warning("Push notifications not supported")
return GetTaskPushNotificationResponse(
id=request.id, error=PushNotificationNotSupportedError()
)
# Check if task exists
task_data = await self.redis_cache.get(f"task:{task_params.id}")
if not task_data:
logger.warning(
f"Task {task_params.id} not found for getting push notification"
)
return GetTaskPushNotificationResponse(
id=request.id, error=TaskNotFoundError()
)
# Get push notification config
config = await self.redis_cache.get(f"task:{task_params.id}:push")
if not config:
logger.warning(f"No push notification config for task {task_params.id}")
return GetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message="No push notification configuration found"
),
)
result = TaskPushNotificationConfig(
id=task_params.id,
pushNotificationConfig=PushNotificationConfig(
url=config.get("url"), token=None, authentication=None
),
)
return GetTaskPushNotificationResponse(id=request.id, result=result)
except Exception as e:
logger.error(f"Error getting push notification: {str(e)}", exc_info=True)
return GetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message=f"Error getting push notification: {str(e)}"
),
)
async def on_resubscribe_to_task(
self, request: TaskResubscriptionRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
"""
Resubscribe to a task's streaming events.
Args:
request: The JSON-RPC request to resubscribe
Returns:
Stream of events or error response
"""
logger.info(f"Resubscribing to task {request.params.id}")
task_params = request.params
try:
# Check if task exists
task_data = await self.redis_cache.get(f"task:{task_params.id}")
if not task_data:
logger.warning(f"Task {task_params.id} not found for resubscription")
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
# Setup SSE consumer with resubscribe flag
try:
sse_queue = await self._setup_sse_consumer(
task_params.id, is_resubscribe=True
)
except ValueError:
logger.warning(
f"Task {task_params.id} not available for resubscription"
)
return JSONRPCResponse(
id=request.id,
error=InternalError(
message="Task not available for resubscription"
),
)
# Send initial status update to the new subscriber
status = task_data.get("status", {})
final = status.get("state") in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]
# Convert to TaskStatus object
task_status = TaskStatus(
state=status.get("state", TaskState.UNKNOWN),
timestamp=datetime.fromisoformat(
status.get("timestamp", datetime.now().isoformat())
),
message=status.get("message"),
)
# Publish to the specific queue
await sse_queue.put(
TaskStatusUpdateEvent(
id=task_params.id, status=task_status, final=final
)
)
# Return generator to dequeue events for SSE
return self._dequeue_events_for_sse(request.id, task_params.id, sse_queue)
except Exception as e:
logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True)
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error resubscribing to task: {str(e)}"),
)
async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]:
"""
Create initial task data structure.
Args:
params: Task send parameters
Returns:
Task data dictionary
"""
# Create task with initial status
task_data = {
"id": params.id,
"sessionId": params.sessionId
or str(uuid.uuid4()), # Preservar sessionId quando fornecido
"status": {
"state": TaskState.SUBMITTED,
"timestamp": datetime.now().isoformat(),
"message": None,
"error": None,
},
"artifacts": [],
"history": [params.message.model_dump()], # Apenas mensagem do usuário
"metadata": params.metadata or {},
}
# Save task data
await self.redis_cache.set(f"task:{params.id}", task_data)
return task_data
async def _execute_task(
self, task: Dict[str, Any], params: TaskSendParams
) -> Dict[str, Any]:
"""
Execute a task using the agent adapter.
This function is responsible for executing the task by the agent,
updating its status as progress is made.
Args:
task: Task data to be executed
params: Send task parameters
Returns:
Updated task data with completed status and response
"""
task_id = task["id"]
agent_id = params.agentId
message_text = ""
# Extract the text from the message
if params.message and params.message.parts:
for part in params.message.parts:
if part.type == "text":
message_text += part.text
if not message_text:
await self._update_task_status_without_history(
task_id, TaskState.FAILED, "Message does not contain text", final=True
)
# Return the updated task data
return await self.redis_cache.get(f"task:{task_id}")
# Check if it is an ongoing execution
task_status = task.get("status", {})
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
logger.info(f"Task {task_id} is already in execution or completed")
# Return the current task data
return await self.redis_cache.get(f"task:{task_id}")
try:
# Update to "working" state - NÃO adicionar ao histórico
await self._update_task_status_without_history(
task_id, TaskState.WORKING, "Processing request"
)
# Execute the agent
if self.agent_runner:
response = await self.agent_runner.run_agent(
agent_id=agent_id,
message=message_text,
session_id=params.sessionId, # Usar o sessionId da requisição
task_id=task_id,
)
# Process the agent's response
if response and isinstance(response, dict):
# Extract text from the response
response_text = response.get("content", "")
if not response_text and "message" in response:
message = response.get("message", {})
parts = message.get("parts", [])
for part in parts:
if part.get("type") == "text":
response_text += part.get("text", "")
# Build the final agent message
if response_text:
# Atualizar o histórico com a mensagem do usuário
await self._update_task_history(task_id, params.message)
# Create an artifact for the response in Google A2A format
artifact = {
"parts": [{"type": "text", "text": response_text}],
"index": 0,
}
# Add the artifact to the task
await self._add_task_artifact(task_id, artifact)
# Update the task status to completed (sem adicionar ao histórico)
await self._update_task_status_without_history(
task_id, TaskState.COMPLETED, response_text, final=True
)
else:
await self._update_task_status_without_history(
task_id,
TaskState.FAILED,
"The agent did not return a valid response",
final=True,
)
else:
await self._update_task_status_without_history(
task_id,
TaskState.FAILED,
"Invalid agent response",
final=True,
)
else:
await self._update_task_status_without_history(
task_id,
TaskState.FAILED,
"Agent adapter not configured",
final=True,
)
except Exception as e:
logger.error(f"Error executing task {task_id}: {str(e)}")
await self._update_task_status_without_history(
task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
)
# Return the updated task data
return await self.redis_cache.get(f"task:{task_id}")
async def _update_task_history(self, task_id: str, message) -> None:
"""
Atualiza o histórico da tarefa incluindo apenas mensagens do usuário.
Args:
task_id: ID da tarefa
message: Mensagem do usuário para adicionar ao histórico
"""
if not message:
return
# Obter dados da tarefa atual
task_data = await self.redis_cache.get(f"task:{task_id}")
if not task_data:
logger.warning(f"Task {task_id} not found for history update")
return
# Garantir que há um campo de histórico
if "history" not in task_data:
task_data["history"] = []
# Verificar se a mensagem já existe no histórico para evitar duplicação
user_message = (
message.model_dump() if hasattr(message, "model_dump") else message
)
message_exists = False
for msg in task_data["history"]:
if self._compare_messages(msg, user_message):
message_exists = True
break
# Adicionar mensagem se não existir
if not message_exists:
task_data["history"].append(user_message)
logger.info(f"Added new user message to history for task {task_id}")
# Salvar tarefa atualizada
await self.redis_cache.set(f"task:{task_id}", task_data)
# Método auxiliar para comparar mensagens e evitar duplicação no histórico
def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool:
"""Compara duas mensagens para verificar se são essencialmente iguais."""
if msg1.get("role") != msg2.get("role"):
return False
parts1 = msg1.get("parts", [])
parts2 = msg2.get("parts", [])
if len(parts1) != len(parts2):
return False
for i in range(len(parts1)):
if parts1[i].get("type") != parts2[i].get("type"):
return False
if parts1[i].get("text") != parts2[i].get("text"):
return False
return True
async def _update_task_status_without_history(
self, task_id: str, state: TaskState, message_text: str, final: bool = False
) -> None:
"""
Update the status of a task without changing the history.
Args:
task_id: ID of the task to be updated
state: New task state
message_text: Text of the message associated with the status
final: Indicates if this is the final status of the task
"""
try:
# Get current task data
task_data = await self.redis_cache.get(f"task:{task_id}")
if not task_data:
logger.warning(f"Unable to update status: task {task_id} not found")
return
# Create status object with the message
agent_message = Message(
role="agent",
parts=[TextPart(text=message_text)],
)
status = TaskStatus(
state=state, message=agent_message, timestamp=datetime.now()
)
# Update the status in the task
task_data["status"] = status.model_dump(exclude_none=True)
# Store the updated task
await self.redis_cache.set(f"task:{task_id}", task_data)
# Create status update event
status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final)
# Publish status update
await self._publish_task_update(task_id, status_event)
# Send push notification, if configured
if final or state in [
TaskState.FAILED,
TaskState.COMPLETED,
TaskState.CANCELED,
]:
await self._send_push_notification_for_task(
task_id=task_id, state=state, message_text=message_text
)
except Exception as e:
logger.error(f"Error updating task status {task_id}: {str(e)}")
async def _add_task_artifact(self, task_id: str, artifact) -> None:
"""
Add an artifact to a task and publish the update.
Args:
task_id: Task ID
artifact: Artifact to add (dict no formato do Google)
"""
logger.info(f"Adding artifact to task {task_id}")
# Update task data
task_data = await self.redis_cache.get(f"task:{task_id}")
if task_data:
if "artifacts" not in task_data:
task_data["artifacts"] = []
# Adicionar o artefato sem substituir os existentes
task_data["artifacts"].append(artifact)
await self.redis_cache.set(f"task:{task_id}", task_data)
# Criar um artefato do tipo Artifact para o evento
artifact_obj = Artifact(
parts=[
TextPart(text=part.get("text", ""))
for part in artifact.get("parts", [])
],
index=artifact.get("index", 0),
)
# Create artifact update event
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj)
# Publish event
await self._publish_task_update(task_id, event)
async def _publish_task_update(
self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]
) -> None:
"""
Publish a task update event to all subscribers.
Args:
task_id: Task ID
event: Event to publish
"""
async with self.subscriber_lock:
if task_id not in self.task_sse_subscribers:
return
subscribers = self.task_sse_subscribers[task_id]
for subscriber in subscribers:
try:
await subscriber.put(event)
except Exception as e:
logger.error(f"Error publishing event to subscriber: {str(e)}")
async def _send_push_notification_for_task(
self,
task_id: str,
state: str,
message_text: str = None,
system_message: str = None,
) -> None:
"""
Send push notification for a task if configured.
Args:
task_id: Task ID
state: Task state
message_text: Optional message text
system_message: Optional system message
"""
if not self.push_notification_service:
return
try:
# Get push notification config
config = await self.redis_cache.get(f"task:{task_id}:push")
if not config:
return
# Create message if provided
message = None
if message_text:
message = {
"role": "agent",
"parts": [{"type": "text", "text": message_text}],
}
elif system_message:
# We use 'agent' instead of 'system' since Message only accepts 'user' or 'agent'
message = {
"role": "agent",
"parts": [{"type": "text", "text": system_message}],
}
# Send notification
await self.push_notification_service.send_notification(
url=config["url"],
task_id=task_id,
state=state,
message=message,
headers=config.get("headers", {}),
)
except Exception as e:
logger.error(
f"Error sending push notification for task {task_id}: {str(e)}"
)
async def _setup_sse_consumer(
self, task_id: str, is_resubscribe: bool = False
) -> asyncio.Queue:
"""
Set up an SSE consumer queue for a task.
Args:
task_id: Task ID
is_resubscribe: Whether this is a resubscription
Returns:
Queue for events
Raises:
ValueError: If resubscribing to non-existent task
"""
async with self.subscriber_lock:
if task_id not in self.task_sse_subscribers:
if is_resubscribe:
raise ValueError("Task not found for resubscription")
self.task_sse_subscribers[task_id] = []
queue = asyncio.Queue()
self.task_sse_subscribers[task_id].append(queue)
return queue
async def _dequeue_events_for_sse(
self, request_id: str, task_id: str, event_queue: asyncio.Queue
) -> AsyncIterable[SendTaskStreamingResponse]:
"""
Dequeue and yield events for SSE streaming.
Args:
request_id: Request ID
task_id: Task ID
event_queue: Queue for events
Yields:
SSE events wrapped in SendTaskStreamingResponse
"""
try:
while True:
event = await event_queue.get()
if isinstance(event, JSONRPCError):
yield SendTaskStreamingResponse(id=request_id, error=event)
break
yield SendTaskStreamingResponse(id=request_id, result=event)
# Check if this is the final event
is_final = False
if hasattr(event, "final") and event.final:
is_final = True
if is_final:
break
finally:
# Clean up the subscription when done
async with self.subscriber_lock:
if task_id in self.task_sse_subscribers:
try:
self.task_sse_subscribers[task_id].remove(event_queue)
# Remove the task from the dict if no more subscribers
if not self.task_sse_subscribers[task_id]:
del self.task_sse_subscribers[task_id]
except ValueError:
pass # Queue might have been removed already

View File

@ -1,265 +0,0 @@
"""
Push Notification Authentication Service.
This service implements JWT authentication for A2A push notifications,
allowing agents to send authenticated notifications and clients to verify
the authenticity of received notifications.
"""
from jwcrypto import jwk
import uuid
import time
import json
import hashlib
import httpx
import logging
import jwt
from jwt import PyJWK, PyJWKClient
from fastapi import Request
from starlette.responses import JSONResponse
from typing import Dict, Any
logger = logging.getLogger(__name__)
AUTH_HEADER_PREFIX = "Bearer "
class PushNotificationAuth:
"""
Base class for push notification authentication.
Contains common methods used by both the sender and the receiver
of push notifications.
"""
def _calculate_request_body_sha256(self, data: Dict[str, Any]) -> str:
"""
Calculates the SHA256 hash of the request body.
This logic needs to be the same for the agent that signs the payload
and for the client that verifies it.
Args:
data: Request body data
Returns:
SHA256 hash as a hexadecimal string
"""
body_str = json.dumps(
data,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
)
return hashlib.sha256(body_str.encode()).hexdigest()
class PushNotificationSenderAuth(PushNotificationAuth):
"""
Authentication for the push notification sender.
This class is used by the A2A server to authenticate notifications
sent to callback URLs of clients.
"""
def __init__(self):
"""
Initialize the push notification sender authentication service.
"""
self.public_keys = []
self.private_key_jwk = None
@staticmethod
async def verify_push_notification_url(url: str) -> bool:
"""
Verifies if a push notification URL is valid and responds correctly.
Sends a validation token and verifies if the response contains the same token.
Args:
url: URL to be verified
Returns:
True if the URL is verified successfully, False otherwise
"""
async with httpx.AsyncClient(timeout=10) as client:
try:
validation_token = str(uuid.uuid4())
response = await client.get(
url, params={"validationToken": validation_token}
)
response.raise_for_status()
is_verified = response.text == validation_token
logger.info(f"Push notification URL verified: {url} => {is_verified}")
return is_verified
except Exception as e:
logger.warning(f"Error verifying push notification URL {url}: {e}")
return False
def generate_jwk(self):
"""
Generates a new JWK pair for signing.
The key pair is used to sign push notifications.
The public key is available via the JWKS endpoint.
"""
key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
self.public_keys.append(key.export_public(as_dict=True))
self.private_key_jwk = PyJWK.from_json(key.export_private())
def handle_jwks_endpoint(self, _request: Request) -> JSONResponse:
"""
Handles the JWKS endpoint to allow clients to obtain the public keys.
Args:
_request: HTTP request
Returns:
JSON response with the public keys
"""
return JSONResponse({"keys": self.public_keys})
def _generate_jwt(self, data: Dict[str, Any]) -> str:
"""
Generates a JWT token by signing the SHA256 hash of the payload and the timestamp.
The payload is signed with the private key to ensure integrity.
The timestamp (iat) prevents replay attacks.
Args:
data: Payload data
Returns:
Signed JWT token
"""
iat = int(time.time())
return jwt.encode(
{
"iat": iat,
"request_body_sha256": self._calculate_request_body_sha256(data),
},
key=self.private_key_jwk.key,
headers={"kid": self.private_key_jwk.key_id},
algorithm="RS256",
)
async def send_push_notification(self, url: str, data: Dict[str, Any]) -> bool:
"""
Sends an authenticated push notification to the specified URL.
Args:
url: URL to send the notification
data: Notification data
Returns:
True if the notification was sent successfully, False otherwise
"""
if not self.private_key_jwk:
logger.error(
"Attempt to send push notification without generating JWK keys"
)
return False
try:
jwt_token = self._generate_jwt(data)
headers = {"Authorization": f"Bearer {jwt_token}"}
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Push notification sent to URL: {url}")
return True
except Exception as e:
logger.warning(f"Error sending push notification to URL {url}: {e}")
return False
class PushNotificationReceiverAuth(PushNotificationAuth):
"""
Authentication for the push notification receiver.
This class is used by clients to verify the authenticity
of notifications received from the A2A server.
"""
def __init__(self):
"""
Initialize the push notification receiver authentication service.
"""
self.jwks_client = None
async def load_jwks(self, jwks_url: str):
"""
Loads the public JWKS keys from a URL.
Args:
jwks_url: URL of the JWKS endpoint
"""
self.jwks_client = PyJWKClient(jwks_url)
async def verify_push_notification(self, request: Request) -> bool:
"""
Verifies the authenticity of a push notification.
Args:
request: HTTP request containing the notification
Returns:
True if the notification is authentic, False otherwise
Raises:
ValueError: If the token is invalid or expired
"""
if not self.jwks_client:
logger.error("Attempt to verify notification without loading JWKS keys")
return False
# Verify authentication header
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
logger.warning("Invalid authorization header")
return False
try:
# Extract and verify token
token = auth_header[len(AUTH_HEADER_PREFIX) :]
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
# Decode token
decode_token = jwt.decode(
token,
signing_key.key,
options={"require": ["iat", "request_body_sha256"]},
algorithms=["RS256"],
)
# Verify request body integrity
body_data = await request.json()
actual_body_sha256 = self._calculate_request_body_sha256(body_data)
if actual_body_sha256 != decode_token["request_body_sha256"]:
# The payload signature does not match the hash in the signed token
logger.warning("Request body hash does not match the token")
raise ValueError("Invalid request body")
# Verify token age (maximum 5 minutes)
if time.time() - decode_token["iat"] > 60 * 5:
# Do not allow notifications older than 5 minutes
# This prevents replay attacks
logger.warning("Token expired")
raise ValueError("Token expired")
return True
except Exception as e:
logger.error(f"Error verifying push notification: {e}")
return False
# Global instance of the push notification sender authentication service
push_notification_auth = PushNotificationSenderAuth()
# Generate JWK keys on initialization
push_notification_auth.generate_jwk()

View File

@ -1,99 +0,0 @@
import aiohttp
import logging
from datetime import datetime
from typing import Dict, Any, Optional
import asyncio
from src.services.push_notification_auth_service import push_notification_auth
logger = logging.getLogger(__name__)
class PushNotificationService:
def __init__(self):
self.session = aiohttp.ClientSession()
async def send_notification(
self,
url: str,
task_id: str,
state: str,
message: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
use_jwt_auth: bool = True,
) -> bool:
"""
Send a push notification to the specified URL.
Implements exponential backoff retry.
Args:
url: URL to send the notification
task_id: Task ID
state: Task state
message: Optional message
headers: Optional HTTP headers
max_retries: Maximum number of retries
retry_delay: Initial delay between retries
use_jwt_auth: Whether to use JWT authentication
Returns:
True if the notification was sent successfully, False otherwise
"""
payload = {
"taskId": task_id,
"state": state,
"timestamp": datetime.now().isoformat(),
"message": message,
}
# First URL verification
if use_jwt_auth:
is_url_valid = await push_notification_auth.verify_push_notification_url(
url
)
if not is_url_valid:
logger.warning(f"Invalid push notification URL: {url}")
return False
for attempt in range(max_retries):
try:
if use_jwt_auth:
# Use JWT authentication
success = await push_notification_auth.send_push_notification(
url, payload
)
if success:
return True
else:
# Legacy method without JWT authentication
async with self.session.post(
url, json=payload, headers=headers or {}, timeout=10
) as response:
if response.status in (200, 201, 202, 204):
logger.info(f"Push notification sent to URL: {url}")
return True
else:
logger.warning(
f"Failed to send push notification with status {response.status}. "
f"Attempt {attempt + 1}/{max_retries}"
)
except Exception as e:
logger.error(
f"Error sending push notification: {str(e)}. "
f"Attempt {attempt + 1}/{max_retries}"
)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay * (2**attempt))
return False
async def close(self):
"""Close the HTTP session"""
await self.session.close()
# Global instance of the push notification service
push_notification_service = PushNotificationService()

View File

@ -1,555 +0,0 @@
"""
Cache Redis service for the A2A protocol.
This service provides an interface for storing and retrieving data related to tasks,
push notification configurations, and other A2A-related data.
"""
import json
import logging
from typing import Any, Dict, List, Optional
import asyncio
import redis.asyncio as aioredis
from src.config.redis import get_redis_config
import threading
import time
logger = logging.getLogger(__name__)
class _InMemoryCacheFallback:
"""
Fallback in-memory cache implementation for when Redis is not available.
This should only be used for development or testing environments.
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
"""Singleton implementation."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Initialize cache storage."""
if not getattr(self, "_initialized", False):
with self._lock:
if not getattr(self, "_initialized", False):
self._data = {}
self._ttls = {}
self._hash_data = {}
self._list_data = {}
self._data_lock = threading.Lock()
self._initialized = True
logger.warning(
"Initializing in-memory cache fallback (not for production)"
)
async def set(self, key, value, ex=None):
"""Set a key with optional expiration."""
with self._data_lock:
self._data[key] = value
if ex is not None:
self._ttls[key] = time.time() + ex
elif key in self._ttls:
del self._ttls[key]
return True
async def setex(self, key, ex, value):
"""Set a key with expiration."""
return await self.set(key, value, ex)
async def get(self, key):
"""Get a key value."""
with self._data_lock:
# Check if expired
if key in self._ttls and time.time() > self._ttls[key]:
del self._data[key]
del self._ttls[key]
return None
return self._data.get(key)
async def delete(self, key):
"""Delete a key."""
with self._data_lock:
if key in self._data:
del self._data[key]
if key in self._ttls:
del self._ttls[key]
return 1
return 0
async def exists(self, key):
"""Check if key exists."""
with self._data_lock:
if key in self._ttls and time.time() > self._ttls[key]:
del self._data[key]
del self._ttls[key]
return 0
return 1 if key in self._data else 0
async def hset(self, key, field, value):
"""Set a hash field."""
with self._data_lock:
if key not in self._hash_data:
self._hash_data[key] = {}
self._hash_data[key][field] = value
return 1
async def hget(self, key, field):
"""Get a hash field."""
with self._data_lock:
if key not in self._hash_data:
return None
return self._hash_data[key].get(field)
async def hdel(self, key, field):
"""Delete a hash field."""
with self._data_lock:
if key in self._hash_data and field in self._hash_data[key]:
del self._hash_data[key][field]
return 1
return 0
async def hgetall(self, key):
"""Get all hash fields."""
with self._data_lock:
if key not in self._hash_data:
return {}
return dict(self._hash_data[key])
async def rpush(self, key, value):
"""Add to a list."""
with self._data_lock:
if key not in self._list_data:
self._list_data[key] = []
self._list_data[key].append(value)
return len(self._list_data[key])
async def lrange(self, key, start, end):
"""Get range from list."""
with self._data_lock:
if key not in self._list_data:
return []
# Handle negative indices
if end < 0:
end = len(self._list_data[key]) + end + 1
return self._list_data[key][start:end]
async def expire(self, key, seconds):
"""Set expiration on key."""
with self._data_lock:
if key in self._data:
self._ttls[key] = time.time() + seconds
return 1
return 0
async def flushdb(self):
"""Clear all data."""
with self._data_lock:
self._data.clear()
self._ttls.clear()
self._hash_data.clear()
self._list_data.clear()
return True
async def keys(self, pattern="*"):
"""Get keys matching pattern."""
with self._data_lock:
# Clean expired keys
now = time.time()
expired_keys = [k for k, exp in self._ttls.items() if now > exp]
for k in expired_keys:
if k in self._data:
del self._data[k]
del self._ttls[k]
# Simple pattern matching
result = []
if pattern == "*":
result = list(self._data.keys())
elif pattern.endswith("*"):
prefix = pattern[:-1]
result = [k for k in self._data.keys() if k.startswith(prefix)]
elif pattern.startswith("*"):
suffix = pattern[1:]
result = [k for k in self._data.keys() if k.endswith(suffix)]
else:
if pattern in self._data:
result = [pattern]
return result
async def ping(self):
"""Test connection."""
return True
class RedisCacheService:
"""
Cache service using Redis for storing A2A-related data.
This implementation uses a real Redis connection for distributed caching
and data persistence across multiple instances.
If Redis is not available, falls back to an in-memory implementation.
"""
def __init__(self, redis_url: Optional[str] = None):
"""
Initialize the Redis cache service.
Args:
redis_url: Redis server URL (optional, defaults to config value)
"""
if redis_url:
self._redis_url = redis_url
else:
# Construir URL a partir dos componentes de configuração
config = get_redis_config()
protocol = "rediss" if config.get("ssl", False) else "redis"
auth = f":{config['password']}@" if config.get("password") else ""
self._redis_url = (
f"{protocol}://{auth}{config['host']}:{config['port']}/{config['db']}"
)
self._redis = None
self._in_memory_mode = False
self._connecting = False
self._connection_lock = asyncio.Lock()
logger.info(f"Initializing RedisCacheService with URL: {self._redis_url}")
async def _get_redis(self):
"""
Get a Redis connection, creating it if necessary.
Falls back to in-memory implementation if Redis is not available.
Returns:
Redis connection or in-memory fallback
"""
if self._redis is not None:
return self._redis
async with self._connection_lock:
if self._redis is None and not self._connecting:
try:
self._connecting = True
logger.info(f"Connecting to Redis at {self._redis_url}")
self._redis = aioredis.from_url(
self._redis_url, encoding="utf-8", decode_responses=True
)
# Teste de conexão
await self._redis.ping()
logger.info("Redis connection successful")
except Exception as e:
logger.error(f"Error connecting to Redis: {str(e)}")
logger.warning(
"Falling back to in-memory cache (not suitable for production)"
)
self._redis = _InMemoryCacheFallback()
self._in_memory_mode = True
finally:
self._connecting = False
return self._redis
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""
Store a value in the cache.
Args:
key: Key for the value
value: Value to store
ttl: Time to live in seconds (optional)
"""
try:
redis = await self._get_redis()
# Convert dict/list to JSON string
if isinstance(value, (dict, list)):
value = json.dumps(value)
if ttl:
await redis.setex(key, ttl, value)
else:
await redis.set(key, value)
logger.debug(f"Set cache key: {key}")
except Exception as e:
logger.error(f"Error setting Redis key {key}: {str(e)}")
async def get(self, key: str) -> Optional[Any]:
"""
Retrieve a value from the cache.
Args:
key: Key for the value to retrieve
Returns:
The stored value or None if not found
"""
try:
redis = await self._get_redis()
value = await redis.get(key)
if value is None:
return None
try:
# Try to parse as JSON
return json.loads(value)
except json.JSONDecodeError:
# Return as-is if not JSON
return value
except Exception as e:
logger.error(f"Error getting Redis key {key}: {str(e)}")
return None
async def delete(self, key: str) -> bool:
"""
Remove a value from the cache.
Args:
key: Key for the value to remove
Returns:
True if the value was removed, False if it didn't exist
"""
try:
redis = await self._get_redis()
result = await redis.delete(key)
return result > 0
except Exception as e:
logger.error(f"Error deleting Redis key {key}: {str(e)}")
return False
async def exists(self, key: str) -> bool:
"""
Check if a key exists in the cache.
Args:
key: Key to check
Returns:
True if the key exists, False otherwise
"""
try:
redis = await self._get_redis()
return await redis.exists(key) > 0
except Exception as e:
logger.error(f"Error checking Redis key {key}: {str(e)}")
return False
async def set_hash(self, key: str, field: str, value: Any) -> None:
"""
Store a value in a hash.
Args:
key: Hash key
field: Hash field
value: Value to store
"""
try:
redis = await self._get_redis()
# Convert dict/list to JSON string
if isinstance(value, (dict, list)):
value = json.dumps(value)
await redis.hset(key, field, value)
logger.debug(f"Set hash field: {key}:{field}")
except Exception as e:
logger.error(f"Error setting Redis hash {key}:{field}: {str(e)}")
async def get_hash(self, key: str, field: str) -> Optional[Any]:
"""
Retrieve a value from a hash.
Args:
key: Hash key
field: Hash field
Returns:
The stored value or None if not found
"""
try:
redis = await self._get_redis()
value = await redis.hget(key, field)
if value is None:
return None
try:
# Try to parse as JSON
return json.loads(value)
except json.JSONDecodeError:
# Return as-is if not JSON
return value
except Exception as e:
logger.error(f"Error getting Redis hash {key}:{field}: {str(e)}")
return None
async def delete_hash(self, key: str, field: str) -> bool:
"""
Remove a value from a hash.
Args:
key: Hash key
field: Hash field
Returns:
True if the value was removed, False if it didn't exist
"""
try:
redis = await self._get_redis()
result = await redis.hdel(key, field)
return result > 0
except Exception as e:
logger.error(f"Error deleting Redis hash {key}:{field}: {str(e)}")
return False
async def get_all_hash(self, key: str) -> Dict[str, Any]:
"""
Retrieve all values from a hash.
Args:
key: Hash key
Returns:
Dictionary with all hash values
"""
try:
redis = await self._get_redis()
result_dict = await redis.hgetall(key)
if not result_dict:
return {}
# Try to parse each value as JSON
parsed_dict = {}
for field, value in result_dict.items():
try:
parsed_dict[field] = json.loads(value)
except json.JSONDecodeError:
parsed_dict[field] = value
return parsed_dict
except Exception as e:
logger.error(f"Error getting all Redis hash fields for {key}: {str(e)}")
return {}
async def push_list(self, key: str, value: Any) -> int:
"""
Add a value to the end of a list.
Args:
key: List key
value: Value to add
Returns:
Size of the list after the addition
"""
try:
redis = await self._get_redis()
# Convert dict/list to JSON string
if isinstance(value, (dict, list)):
value = json.dumps(value)
return await redis.rpush(key, value)
except Exception as e:
logger.error(f"Error pushing to Redis list {key}: {str(e)}")
return 0
async def get_list(self, key: str, start: int = 0, end: int = -1) -> List[Any]:
"""
Retrieve values from a list.
Args:
key: List key
start: Initial index (inclusive)
end: Final index (inclusive), -1 for all
Returns:
List with the retrieved values
"""
try:
redis = await self._get_redis()
values = await redis.lrange(key, start, end)
if not values:
return []
# Try to parse each value as JSON
result = []
for value in values:
try:
result.append(json.loads(value))
except json.JSONDecodeError:
result.append(value)
return result
except Exception as e:
logger.error(f"Error getting Redis list {key}: {str(e)}")
return []
async def expire(self, key: str, ttl: int) -> bool:
"""
Set a time-to-live for a key.
Args:
key: Key
ttl: Time-to-live in seconds
Returns:
True if the key exists and the TTL was set, False otherwise
"""
try:
redis = await self._get_redis()
return await redis.expire(key, ttl)
except Exception as e:
logger.error(f"Error setting expire for Redis key {key}: {str(e)}")
return False
async def clear(self) -> None:
"""
Clear the entire cache.
Warning: This is a destructive operation and will remove all data.
Only use in development/test environments.
"""
try:
redis = await self._get_redis()
await redis.flushdb()
logger.warning("Redis database flushed - all data cleared")
except Exception as e:
logger.error(f"Error clearing Redis database: {str(e)}")
async def keys(self, pattern: str = "*") -> List[str]:
"""
Retrieve keys that match a pattern.
Args:
pattern: Glob pattern to filter keys
Returns:
List of keys that match the pattern
"""
try:
redis = await self._get_redis()
return await redis.keys(pattern)
except Exception as e:
logger.error(f"Error getting Redis keys with pattern {pattern}: {str(e)}")
return []

View File

@ -1,141 +0,0 @@
import uuid
import json
from typing import AsyncGenerator, Dict, Any
from fastapi import HTTPException
from datetime import datetime
from src.schemas.streaming import (
JSONRPCRequest,
TaskStatusUpdateEvent,
)
from src.services.agent_runner import run_agent
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from sqlalchemy.orm import Session
class StreamingService:
def __init__(self):
self.active_connections: Dict[str, Any] = {}
async def send_task_streaming(
self,
agent_id: str,
api_key: str,
message: str,
contact_id: str = None,
session_id: str = None,
db: Session = None,
) -> AsyncGenerator[str, None]:
"""
Starts the SSE event streaming for a task.
Args:
agent_id: Agent ID
api_key: API key for authentication
message: Initial message
contact_id: Contact ID (optional)
session_id: Session ID (optional)
db: Database session
Yields:
Formatted SSE events
"""
# Basic API key validation
if not api_key:
raise HTTPException(status_code=401, detail="API key is required")
# Generate unique IDs
task_id = contact_id or str(uuid.uuid4())
request_id = str(uuid.uuid4())
# Build JSON-RPC payload
payload = JSONRPCRequest(
id=request_id,
params={
"id": task_id,
"sessionId": session_id,
"message": {
"role": "user",
"parts": [{"type": "text", "text": message}],
},
},
)
# Register connection
self.active_connections[task_id] = {
"agent_id": agent_id,
"api_key": api_key,
"session_id": session_id,
}
try:
# Send start event
yield self._format_sse_event(
"status",
TaskStatusUpdateEvent(
state="working",
timestamp=datetime.now().isoformat(),
message=payload.params["message"],
).model_dump_json(),
)
# Execute the agent
result = await run_agent(
str(agent_id),
contact_id or task_id,
message,
session_service,
artifacts_service,
memory_service,
db,
session_id,
)
# Send the agent's response as a separate event
yield self._format_sse_event(
"message",
json.dumps(
{
"role": "agent",
"content": result,
"timestamp": datetime.now().isoformat(),
}
),
)
# Completion event
yield self._format_sse_event(
"status",
TaskStatusUpdateEvent(
state="completed",
timestamp=datetime.now().isoformat(),
).model_dump_json(),
)
except Exception as e:
# Error event
yield self._format_sse_event(
"status",
TaskStatusUpdateEvent(
state="failed",
timestamp=datetime.now().isoformat(),
error={"message": str(e)},
).model_dump_json(),
)
raise
finally:
# Clean connection
self.active_connections.pop(task_id, None)
def _format_sse_event(self, event_type: str, data: str) -> str:
"""Format an SSE event."""
return f"event: {event_type}\ndata: {data}\n\n"
async def close_connection(self, task_id: str):
"""Close a streaming connection."""
if task_id in self.active_connections:
self.active_connections.pop(task_id)

View File

@ -1,110 +1,28 @@
"""
A2A protocol utilities.
This module contains utility functions for the A2A protocol implementation.
"""
import logging
from typing import List, Optional, Any, Dict
from src.schemas.a2a import (
from src.schemas.a2a_types import (
ContentTypeNotSupportedError,
UnsupportedOperationError,
JSONRPCResponse,
UnsupportedOperationError,
)
logger = logging.getLogger(__name__)
def are_modalities_compatible(
server_output_modes: Optional[List[str]], client_output_modes: Optional[List[str]]
) -> bool:
"""
Check if server and client modalities are compatible.
Modalities are compatible if they are both non-empty
server_output_modes: list[str], client_output_modes: list[str]
):
"""Modalities are compatible if they are both non-empty
and there is at least one common element.
Args:
server_output_modes: List of output modes supported by the server
client_output_modes: List of output modes requested by the client
Returns:
True if compatible, False otherwise
"""
# If client doesn't specify modes, assume all are accepted
if client_output_modes is None or len(client_output_modes) == 0:
return True
# If server doesn't specify modes, assume all are supported
if server_output_modes is None or len(server_output_modes) == 0:
return True
# Check if there's at least one common mode
return any(mode in server_output_modes for mode in client_output_modes)
return any(x in server_output_modes for x in client_output_modes)
def new_incompatible_types_error(request_id: str) -> JSONRPCResponse:
"""
Create a JSON-RPC response for incompatible content types error.
Args:
request_id: The ID of the request that caused the error
Returns:
JSON-RPC response with ContentTypeNotSupportedError
"""
def new_incompatible_types_error(request_id):
return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())
def new_not_implemented_error(request_id: str) -> JSONRPCResponse:
"""
Create a JSON-RPC response for unimplemented operation error.
Args:
request_id: The ID of the request that caused the error
Returns:
JSON-RPC response with UnsupportedOperationError
"""
def new_not_implemented_error(request_id):
return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str:
"""
Create a unique task ID for an agent and session.
Args:
agent_id: The ID of the agent
session_id: The ID of the session
timestamp: Optional timestamp to include in the ID
Returns:
A unique task ID
"""
import uuid
import time
timestamp = timestamp or str(int(time.time()))
unique = uuid.uuid4().hex[:8]
return f"{agent_id}_{session_id}_{timestamp}_{unique}"
def format_error_response(error: Exception, request_id: str = None) -> Dict[str, Any]:
"""
Format an exception as a JSON-RPC error response.
Args:
error: The exception to format
request_id: The ID of the request that caused the error
Returns:
JSON-RPC error response as dictionary
"""
from src.schemas.a2a import InternalError, JSONRPCResponse
error_response = JSONRPCResponse(
id=request_id, error=InternalError(message=str(error))
)
return error_response.model_dump(exclude_none=True)