diff --git a/src/api/a2a_routes.py b/src/api/a2a_routes.py index 92413356..6ab91da8 100644 --- a/src/api/a2a_routes.py +++ b/src/api/a2a_routes.py @@ -6,29 +6,18 @@ This module implements the standard A2A routes according to the specification. import uuid import logging -from fastapi import APIRouter, Depends, HTTPException, status, Header, Request +import json +from fastapi import APIRouter, Depends, Header, Request, HTTPException from sqlalchemy.orm import Session from starlette.responses import JSONResponse +from sse_starlette.sse import EventSourceResponse +from src.models.models import Agent from src.config.database import get_db -from src.services import agent_service -from src.services import ( - RedisCacheService, - AgentRunnerAdapter, - StreamingServiceAdapter, - create_agent_card_from_agent, +from src.services.a2a_task_manager import ( + A2ATaskManager, + A2AService, ) -from src.services.a2a_task_manager_service import A2ATaskManager -from src.services.a2a_server_service import A2AServer -from src.services.agent_runner import run_agent -from src.services.service_providers import ( - session_service, - artifacts_service, - memory_service, -) -from src.services.push_notification_service import push_notification_service -from src.services.push_notification_auth_service import push_notification_auth -from src.services.streaming_service import StreamingService logger = logging.getLogger(__name__) @@ -43,85 +32,27 @@ router = APIRouter( }, ) -streaming_service = StreamingService() -redis_cache_service = RedisCacheService() -streaming_adapter = StreamingServiceAdapter(streaming_service) -_task_manager_cache = {} -_agent_runner_cache = {} +def get_a2a_service(db: Session = Depends(get_db)): + task_manager = A2ATaskManager(db) + return A2AService(db, task_manager) -def get_agent_runner_adapter(db=None, reuse=True, agent_id=None): - """ - Get or create an agent runner adapter. +async def verify_api_key(db: Session, x_api_key: str) -> bool: + """Verifies the API key.""" + if not x_api_key: + raise HTTPException(status_code=401, detail="API key not provided") - Args: - db: Database session - reuse: Whether to reuse an existing instance - agent_id: Agent ID to use as cache key - - Returns: - Agent runner adapter instance - """ - cache_key = str(agent_id) if agent_id else "default" - - if reuse and cache_key in _agent_runner_cache: - adapter = _agent_runner_cache[cache_key] - - if db is not None: - adapter.db = db - return adapter - - adapter = AgentRunnerAdapter( - agent_runner_func=run_agent, - session_service=session_service, - artifacts_service=artifacts_service, - memory_service=memory_service, - db=db, + agent = ( + db.query(Agent) + .filter(Agent.config.has_key("api_key")) + .filter(Agent.config["api_key"].astext == x_api_key) + .first() ) - if reuse: - _agent_runner_cache[cache_key] = adapter - - return adapter - - -def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"): - cache_key = str(agent_id) - - if operation_type == "query": - if cache_key in _task_manager_cache: - - task_manager = _task_manager_cache[cache_key] - task_manager.db = db - return task_manager - - return A2ATaskManager( - redis_cache=redis_cache_service, - agent_runner=None, - streaming_service=streaming_adapter, - push_notification_service=push_notification_service, - db=db, - ) - - if reuse and cache_key in _task_manager_cache: - task_manager = _task_manager_cache[cache_key] - task_manager.db = db - return task_manager - - # Create new - agent_runner_adapter = get_agent_runner_adapter( - db=db, reuse=reuse, agent_id=agent_id - ) - task_manager = A2ATaskManager( - redis_cache=redis_cache_service, - agent_runner=agent_runner_adapter, - streaming_service=streaming_adapter, - push_notification_service=push_notification_service, - db=db, - ) - _task_manager_cache[cache_key] = task_manager - return task_manager + if not agent: + raise HTTPException(status_code=401, detail="Invalid API key") + return True @router.post("/{agent_id}") @@ -130,156 +61,42 @@ async def process_a2a_request( request: Request, x_api_key: str = Header(None, alias="x-api-key"), db: Session = Depends(get_db), + a2a_service: A2AService = Depends(get_a2a_service), ): - """ - Main endpoint for processing JSON-RPC requests of the A2A protocol. + """Processes an A2A request.""" + # Verify the API key + if not verify_api_key(db, x_api_key): + raise HTTPException(status_code=401, detail="Invalid API key") - This endpoint processes all JSON-RPC methods of the A2A protocol, including: - - tasks/send: Sending tasks - - tasks/sendSubscribe: Sending tasks with streaming - - tasks/get: Querying task status - - tasks/cancel: Cancelling tasks - - tasks/pushNotification/set: Setting push notifications - - tasks/pushNotification/get: Querying push notification configurations - - tasks/resubscribe: Resubscribing to receive task updates - - Args: - agent_id: Agent ID - request: HTTP request with JSON-RPC payload - x_api_key: API key for authentication - db: Database session - - Returns: - JSON-RPC response or streaming (SSE) depending on the method - """ + # Process the request try: - try: - body = await request.json() - method = body.get("method", "unknown") - request_id = body.get("id") # Extract request ID to ensure it's preserved + request_body = await request.json() + result = await a2a_service.process_request(agent_id, request_body) - # Extrair sessionId do params se for tasks/send - session_id = None - if method == "tasks/send" and body.get("params"): - session_id = body.get("params", {}).get("sessionId") - logger.info(f"Extracted sessionId from request: {session_id}") + # If the response is a streaming response, return as EventSourceResponse + if hasattr(result, "__aiter__"): - is_query_request = method in [ - "tasks/get", - "tasks/cancel", - "tasks/pushNotification/get", - "tasks/resubscribe", - ] + async def event_generator(): + async for item in result: + if hasattr(item, "model_dump_json"): + yield {"data": item.model_dump_json(exclude_none=True)} + else: + yield {"data": json.dumps(item)} - reuse_components = is_query_request - - except Exception as e: - logger.error(f"Error reading request body: {e}") - return JSONResponse( - status_code=400, - content={ - "jsonrpc": "2.0", - "id": None, - "error": { - "code": -32700, - "message": f"Parse error: {str(e)}", - "data": None, - }, - }, - ) - - # Verify if the agent exists - agent = agent_service.get_agent(db, agent_id) - if agent is None: - logger.warning(f"Agent not found: {agent_id}") - return JSONResponse( - status_code=404, - content={ - "jsonrpc": "2.0", - "id": request_id, # Use the extracted request ID - "error": {"code": 404, "message": "Agent not found", "data": None}, - }, - ) - - # Verify API key - agent_config = agent.config - - if x_api_key and agent_config.get("api_key") != x_api_key: - logger.warning(f"Invalid API Key for agent {agent_id}") - return JSONResponse( - status_code=401, - content={ - "jsonrpc": "2.0", - "id": request_id, # Use the extracted request ID - "error": {"code": 401, "message": "Invalid API key", "data": None}, - }, - ) - - a2a_task_manager = get_task_manager( - agent_id, - db=db, - reuse=reuse_components, - operation_type="query" if is_query_request else "execution", - ) - a2a_server = A2AServer(task_manager=a2a_task_manager) - - agent_card = create_agent_card_from_agent(agent, db) - a2a_server.agent_card = agent_card - - # Verify JSON-RPC format - if not body.get("jsonrpc") or body.get("jsonrpc") != "2.0": - logger.error(f"Invalid JSON-RPC format: {body.get('jsonrpc')}") - return JSONResponse( - status_code=400, - content={ - "jsonrpc": "2.0", - "id": request_id, # Use the extracted request ID - "error": { - "code": -32600, - "message": "Invalid Request: jsonrpc must be '2.0'", - "data": None, - }, - }, - ) - - if not body.get("method"): - logger.error("Method not specified in request") - return JSONResponse( - status_code=400, - content={ - "jsonrpc": "2.0", - "id": request_id, # Use the extracted request ID - "error": { - "code": -32600, - "message": "Invalid Request: method is required", - "data": None, - }, - }, - ) - - # Processar a requisição normalmente - return await a2a_server.process_request(request, agent_id=str(agent_id), db=db) + return EventSourceResponse(event_generator()) + # Otherwise, return as JSONResponse + if hasattr(result, "model_dump"): + return JSONResponse(result.model_dump(exclude_none=True)) + return JSONResponse(result) except Exception as e: - logger.error(f"Error processing A2A request: {str(e)}", exc_info=True) - # Try to extract request ID from the body, if available - request_id = None - try: - body = await request.json() - request_id = body.get("id") - except: - pass - + logger.error(f"Error processing A2A request: {e}") return JSONResponse( status_code=500, content={ "jsonrpc": "2.0", - "id": request_id, # Use the extracted request ID or None - "error": { - "code": -32603, - "message": "Internal server error", - "data": {"detail": str(e)}, - }, + "id": None, + "error": {"code": -32603, "message": "Internal server error"}, }, ) @@ -289,81 +106,17 @@ async def get_agent_card( agent_id: uuid.UUID, request: Request, db: Session = Depends(get_db), + a2a_service: A2AService = Depends(get_a2a_service), ): - """ - Endpoint to get the Agent Card in the .well-known format of the A2A protocol. - - This endpoint returns the agent information in the standard A2A format, - including capabilities, authentication information, and skills. - - Args: - agent_id: Agent ID - request: HTTP request - db: Database session - - Returns: - Agent Card in JSON format - """ + """Gets the agent card for the specified agent.""" try: - agent = agent_service.get_agent(db, agent_id) - if agent is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" - ) - - agent_card = create_agent_card_from_agent(agent, db) - - a2a_task_manager = get_task_manager(agent_id, db=db, reuse=True) - a2a_server = A2AServer(task_manager=a2a_task_manager) - - # Configure the A2A server with the agent card - a2a_server.agent_card = agent_card - - # Use the A2A server to deliver the agent card, ensuring protocol compatibility - return await a2a_server.get_agent_card(request, db=db) - + agent_card = a2a_service.get_agent_card(agent_id) + if hasattr(agent_card, "model_dump"): + return JSONResponse(agent_card.model_dump(exclude_none=True)) + return JSONResponse(agent_card) except Exception as e: - logger.error(f"Error generating agent card: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error generating agent card", - ) - - -@router.get("/{agent_id}/.well-known/jwks.json") -async def get_jwks( - agent_id: uuid.UUID, - request: Request, - db: Session = Depends(get_db), -): - """ - Endpoint to get the public JWKS keys for verifying the authenticity - of push notifications. - - Clients can use these keys to verify the authenticity of received notifications. - - Args: - agent_id: Agent ID - request: HTTP request - db: Database session - - Returns: - JSON with the public keys in JWKS format - """ - try: - # Verify if the agent exists - agent = agent_service.get_agent(db, agent_id) - if agent is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" - ) - - # Return the public keys - return push_notification_auth.handle_jwks_endpoint(request) - - except Exception as e: - logger.error(f"Error obtaining JWKS: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error obtaining JWKS", + logger.error(f"Error getting agent card: {e}") + return JSONResponse( + status_code=404, + content={"error": f"Agent not found: {str(e)}"}, ) diff --git a/src/schemas/a2a/__init__.py b/src/schemas/a2a/__init__.py deleted file mode 100644 index 35c904ff..00000000 --- a/src/schemas/a2a/__init__.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -A2A (Agent-to-Agent) schema package. - -This package contains Pydantic schema definitions for the A2A protocol. -""" - -from src.schemas.a2a.types import ( - TaskState, - TextPart, - FileContent, - FilePart, - DataPart, - Part, - Message, - TaskStatus, - Artifact, - Task, - TaskStatusUpdateEvent, - TaskArtifactUpdateEvent, - AuthenticationInfo, - PushNotificationConfig, - TaskIdParams, - TaskQueryParams, - TaskSendParams, - TaskPushNotificationConfig, - JSONRPCMessage, - JSONRPCRequest, - JSONRPCResponse, - JSONRPCError, - SendTaskRequest, - SendTaskResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, - GetTaskRequest, - GetTaskResponse, - CancelTaskRequest, - CancelTaskResponse, - SetTaskPushNotificationRequest, - SetTaskPushNotificationResponse, - GetTaskPushNotificationRequest, - GetTaskPushNotificationResponse, - TaskResubscriptionRequest, - A2ARequest, - AgentProvider, - AgentCapabilities, - AgentAuthentication, - AgentSkill, - AgentCard, -) - -from src.schemas.a2a.exceptions import ( - JSONParseError, - InvalidRequestError, - MethodNotFoundError, - InvalidParamsError, - InternalError, - TaskNotFoundError, - TaskNotCancelableError, - PushNotificationNotSupportedError, - UnsupportedOperationError, - ContentTypeNotSupportedError, - A2AClientError, - A2AClientHTTPError, - A2AClientJSONError, - MissingAPIKeyError, -) - -from src.schemas.a2a.validators import ( - validate_base64, - validate_file_content, - validate_message_parts, - text_to_parts, - parts_to_text, -) - -__all__ = [ - # From types - "TaskState", - "TextPart", - "FileContent", - "FilePart", - "DataPart", - "Part", - "Message", - "TaskStatus", - "Artifact", - "Task", - "TaskStatusUpdateEvent", - "TaskArtifactUpdateEvent", - "AuthenticationInfo", - "PushNotificationConfig", - "TaskIdParams", - "TaskQueryParams", - "TaskSendParams", - "TaskPushNotificationConfig", - "JSONRPCMessage", - "JSONRPCRequest", - "JSONRPCResponse", - "JSONRPCError", - "SendTaskRequest", - "SendTaskResponse", - "SendTaskStreamingRequest", - "SendTaskStreamingResponse", - "GetTaskRequest", - "GetTaskResponse", - "CancelTaskRequest", - "CancelTaskResponse", - "SetTaskPushNotificationRequest", - "SetTaskPushNotificationResponse", - "GetTaskPushNotificationRequest", - "GetTaskPushNotificationResponse", - "TaskResubscriptionRequest", - "A2ARequest", - "AgentProvider", - "AgentCapabilities", - "AgentAuthentication", - "AgentSkill", - "AgentCard", - # From exceptions - "JSONParseError", - "InvalidRequestError", - "MethodNotFoundError", - "InvalidParamsError", - "InternalError", - "TaskNotFoundError", - "TaskNotCancelableError", - "PushNotificationNotSupportedError", - "UnsupportedOperationError", - "ContentTypeNotSupportedError", - "A2AClientError", - "A2AClientHTTPError", - "A2AClientJSONError", - "MissingAPIKeyError", - # From validators - "validate_base64", - "validate_file_content", - "validate_message_parts", - "text_to_parts", - "parts_to_text", -] diff --git a/src/schemas/a2a/exceptions.py b/src/schemas/a2a/exceptions.py deleted file mode 100644 index 19dce091..00000000 --- a/src/schemas/a2a/exceptions.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -A2A (Agent-to-Agent) protocol exception definitions. - -This module contains error types and exceptions for the A2A protocol. -""" - -from src.schemas.a2a.types import JSONRPCError - - -class JSONParseError(JSONRPCError): - """ - Error raised when JSON parsing fails. - """ - - code: int = -32700 - message: str = "Invalid JSON payload" - data: object | None = None - - -class InvalidRequestError(JSONRPCError): - """ - Error raised when request validation fails. - """ - - code: int = -32600 - message: str = "Request payload validation error" - data: object | None = None - - -class MethodNotFoundError(JSONRPCError): - """ - Error raised when the requested method is not found. - """ - - code: int = -32601 - message: str = "Method not found" - data: None = None - - -class InvalidParamsError(JSONRPCError): - """ - Error raised when the parameters are invalid. - """ - - code: int = -32602 - message: str = "Invalid parameters" - data: object | None = None - - -class InternalError(JSONRPCError): - """ - Error raised when an internal error occurs. - """ - - code: int = -32603 - message: str = "Internal error" - data: object | None = None - - -class TaskNotFoundError(JSONRPCError): - """ - Error raised when the requested task is not found. - """ - - code: int = -32001 - message: str = "Task not found" - data: None = None - - -class TaskNotCancelableError(JSONRPCError): - """ - Error raised when a task cannot be canceled. - """ - - code: int = -32002 - message: str = "Task cannot be canceled" - data: None = None - - -class PushNotificationNotSupportedError(JSONRPCError): - """ - Error raised when push notifications are not supported. - """ - - code: int = -32003 - message: str = "Push Notification is not supported" - data: None = None - - -class UnsupportedOperationError(JSONRPCError): - """ - Error raised when an operation is not supported. - """ - - code: int = -32004 - message: str = "This operation is not supported" - data: None = None - - -class ContentTypeNotSupportedError(JSONRPCError): - """ - Error raised when content types are incompatible. - """ - - code: int = -32005 - message: str = "Incompatible content types" - data: None = None - - -# Client exceptions - - -class A2AClientError(Exception): - """ - Base exception for A2A client errors. - """ - - pass - - -class A2AClientHTTPError(A2AClientError): - """ - Exception for HTTP errors in A2A client. - """ - - def __init__(self, status_code: int, message: str): - self.status_code = status_code - self.message = message - super().__init__(f"HTTP Error {status_code}: {message}") - - -class A2AClientJSONError(A2AClientError): - """ - Exception for JSON errors in A2A client. - """ - - def __init__(self, message: str): - self.message = message - super().__init__(f"JSON Error: {message}") - - -class MissingAPIKeyError(Exception): - """ - Exception for missing API key. - """ - - pass diff --git a/src/schemas/a2a/validators.py b/src/schemas/a2a/validators.py deleted file mode 100644 index 9ee32811..00000000 --- a/src/schemas/a2a/validators.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -A2A (Agent-to-Agent) protocol validators. - -This module contains validators for the A2A protocol data. -""" - -from typing import List -import base64 -import re -from pydantic import ValidationError -import logging -from src.schemas.a2a.types import Part, TextPart, FilePart, DataPart, FileContent - -logger = logging.getLogger(__name__) - - -def validate_base64(value: str) -> bool: - """ - Validates if a string is valid base64. - - Args: - value: String to validate - - Returns: - True if valid base64, False otherwise - """ - try: - if not value: - return False - - # Check if the string has base64 characters only - pattern = r"^[A-Za-z0-9+/]+={0,2}$" - if not re.match(pattern, value): - return False - - # Try to decode - base64.b64decode(value) - return True - except Exception as e: - logger.warning(f"Invalid base64 string: {e}") - return False - - -def validate_file_content(file_content: FileContent) -> bool: - """ - Validates file content. - - Args: - file_content: FileContent to validate - - Returns: - True if valid, False otherwise - """ - try: - if file_content.bytes is not None: - return validate_base64(file_content.bytes) - elif file_content.uri is not None: - # Basic URL validation - pattern = r"^https?://.+" - return bool(re.match(pattern, file_content.uri)) - return False - except Exception as e: - logger.warning(f"Invalid file content: {e}") - return False - - -def validate_message_parts(parts: List[Part]) -> bool: - """ - Validates all parts in a message. - - Args: - parts: List of parts to validate - - Returns: - True if all parts are valid, False otherwise - """ - try: - for part in parts: - if isinstance(part, TextPart): - if not part.text or not isinstance(part.text, str): - return False - elif isinstance(part, FilePart): - if not validate_file_content(part.file): - return False - elif isinstance(part, DataPart): - if not part.data or not isinstance(part.data, dict): - return False - else: - return False - return True - except (ValidationError, Exception) as e: - logger.warning(f"Invalid message parts: {e}") - return False - - -def text_to_parts(text: str) -> List[Part]: - """ - Converts a plain text to a list of message parts. - - Args: - text: Plain text to convert - - Returns: - List containing a single TextPart - """ - return [TextPart(text=text)] - - -def parts_to_text(parts: List[Part]) -> str: - """ - Extracts text from a list of message parts. - - Args: - parts: List of parts to extract text from - - Returns: - Concatenated text from all text parts - """ - text = "" - for part in parts: - if isinstance(part, TextPart): - text += part.text - # Could add handling for other part types here - return text diff --git a/src/schemas/a2a/types.py b/src/schemas/a2a_types.py similarity index 60% rename from src/schemas/a2a/types.py rename to src/schemas/a2a_types.py index 9efcb7b3..3f52ef33 100644 --- a/src/schemas/a2a/types.py +++ b/src/schemas/a2a_types.py @@ -1,31 +1,20 @@ -""" -A2A (Agent-to-Agent) protocol type definitions. +from datetime import datetime +from enum import Enum +from typing import Annotated, Any, Literal, TypeVar +from uuid import uuid4 +from typing_extensions import Self -This module contains Pydantic schema definitions for the A2A protocol. -""" - -from typing import Union, Any, List, Optional, Annotated, Literal from pydantic import ( BaseModel, + ConfigDict, Field, TypeAdapter, field_serializer, model_validator, - ConfigDict, ) -from datetime import datetime -from uuid import uuid4 -from enum import Enum -from typing_extensions import Self class TaskState(str, Enum): - """ - Enum for the state of a task in the A2A protocol. - - States follow the A2A protocol specification. - """ - SUBMITTED = "submitted" WORKING = "working" INPUT_REQUIRED = "input-required" @@ -36,22 +25,12 @@ class TaskState(str, Enum): class TextPart(BaseModel): - """ - Represents a text part in a message. - """ - type: Literal["text"] = "text" text: str metadata: dict[str, Any] | None = None class FileContent(BaseModel): - """ - Represents file content in a file part. - - Either bytes or uri must be provided, but not both. - """ - name: str | None = None mimeType: str | None = None bytes: str | None = None @@ -59,9 +38,6 @@ class FileContent(BaseModel): @model_validator(mode="after") def check_content(self) -> Self: - """ - Validates that either bytes or uri is present, but not both. - """ if not (self.bytes or self.uri): raise ValueError("Either 'bytes' or 'uri' must be present in the file data") if self.bytes and self.uri: @@ -72,65 +48,40 @@ class FileContent(BaseModel): class FilePart(BaseModel): - """ - Represents a file part in a message. - """ - type: Literal["file"] = "file" file: FileContent metadata: dict[str, Any] | None = None class DataPart(BaseModel): - """ - Represents a data part in a message. - """ - type: Literal["data"] = "data" data: dict[str, Any] metadata: dict[str, Any] | None = None -Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")] +Part = Annotated[TextPart | FilePart | DataPart, Field(discriminator="type")] class Message(BaseModel): - """ - Represents a message in the A2A protocol. - - A message consists of a role and one or more parts. - """ - role: Literal["user", "agent"] - parts: List[Part] + parts: list[Part] metadata: dict[str, Any] | None = None class TaskStatus(BaseModel): - """ - Represents the status of a task. - """ - state: TaskState message: Message | None = None timestamp: datetime = Field(default_factory=datetime.now) @field_serializer("timestamp") def serialize_dt(self, dt: datetime, _info): - """ - Serializes datetime to ISO format. - """ return dt.isoformat() class Artifact(BaseModel): - """ - Represents an artifact produced by an agent. - """ - name: str | None = None description: str | None = None - parts: List[Part] + parts: list[Part] metadata: dict[str, Any] | None = None index: int = 0 append: bool | None = None @@ -138,23 +89,15 @@ class Artifact(BaseModel): class Task(BaseModel): - """ - Represents a task in the A2A protocol. - """ - id: str sessionId: str | None = None status: TaskStatus - artifacts: List[Artifact] | None = None - history: List[Message] | None = None + artifacts: list[Artifact] | None = None + history: list[Message] | None = None metadata: dict[str, Any] | None = None class TaskStatusUpdateEvent(BaseModel): - """ - Represents a task status update event. - """ - id: str status: TaskStatus final: bool = False @@ -162,295 +105,234 @@ class TaskStatusUpdateEvent(BaseModel): class TaskArtifactUpdateEvent(BaseModel): - """ - Represents a task artifact update event. - """ - id: str artifact: Artifact metadata: dict[str, Any] | None = None class AuthenticationInfo(BaseModel): - """ - Represents authentication information for push notifications. - """ - model_config = ConfigDict(extra="allow") - schemes: List[str] + schemes: list[str] credentials: str | None = None class PushNotificationConfig(BaseModel): - """ - Represents push notification configuration. - """ - url: str token: str | None = None authentication: AuthenticationInfo | None = None class TaskIdParams(BaseModel): - """ - Represents parameters for identifying a task. - """ - id: str metadata: dict[str, Any] | None = None class TaskQueryParams(TaskIdParams): - """ - Represents parameters for querying a task. - """ - historyLength: int | None = None class TaskSendParams(BaseModel): - """ - Represents parameters for sending a task. - """ - id: str sessionId: str = Field(default_factory=lambda: uuid4().hex) message: Message - acceptedOutputModes: Optional[List[str]] = None + acceptedOutputModes: list[str] | None = None pushNotification: PushNotificationConfig | None = None historyLength: int | None = None metadata: dict[str, Any] | None = None - agentId: str = "" class TaskPushNotificationConfig(BaseModel): - """ - Represents push notification configuration for a task. - """ - id: str pushNotificationConfig: PushNotificationConfig -# RPC Messages +## RPC Messages class JSONRPCMessage(BaseModel): - """ - Base class for JSON-RPC messages. - """ - jsonrpc: Literal["2.0"] = "2.0" id: int | str | None = Field(default_factory=lambda: uuid4().hex) class JSONRPCRequest(JSONRPCMessage): - """ - Represents a JSON-RPC request. - """ - method: str params: dict[str, Any] | None = None class JSONRPCError(BaseModel): - """ - Represents a JSON-RPC error. - """ - code: int message: str data: Any | None = None class JSONRPCResponse(JSONRPCMessage): - """ - Represents a JSON-RPC response. - """ - result: Any | None = None error: JSONRPCError | None = None class SendTaskRequest(JSONRPCRequest): - """ - Represents a request to send a task. - """ - method: Literal["tasks/send"] = "tasks/send" params: TaskSendParams class SendTaskResponse(JSONRPCResponse): - """ - Represents a response to a send task request. - """ - result: Task | None = None class SendTaskStreamingRequest(JSONRPCRequest): - """ - Represents a request to send a task with streaming. - """ - method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe" params: TaskSendParams class SendTaskStreamingResponse(JSONRPCResponse): - """ - Represents a streaming response to a send task request. - """ - result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None class GetTaskRequest(JSONRPCRequest): - """ - Represents a request to get task information. - """ - method: Literal["tasks/get"] = "tasks/get" params: TaskQueryParams class GetTaskResponse(JSONRPCResponse): - """ - Represents a response to a get task request. - """ - result: Task | None = None class CancelTaskRequest(JSONRPCRequest): - """ - Represents a request to cancel a task. - """ - method: Literal["tasks/cancel",] = "tasks/cancel" params: TaskIdParams class CancelTaskResponse(JSONRPCResponse): - """ - Represents a response to a cancel task request. - """ - result: Task | None = None class SetTaskPushNotificationRequest(JSONRPCRequest): - """ - Represents a request to set push notification for a task. - """ - method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set" params: TaskPushNotificationConfig class SetTaskPushNotificationResponse(JSONRPCResponse): - """ - Represents a response to a set push notification request. - """ - result: TaskPushNotificationConfig | None = None class GetTaskPushNotificationRequest(JSONRPCRequest): - """ - Represents a request to get push notification configuration for a task. - """ - method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get" params: TaskIdParams class GetTaskPushNotificationResponse(JSONRPCResponse): - """ - Represents a response to a get push notification request. - """ - result: TaskPushNotificationConfig | None = None class TaskResubscriptionRequest(JSONRPCRequest): - """ - Represents a request to resubscribe to a task. - """ - method: Literal["tasks/resubscribe",] = "tasks/resubscribe" params: TaskIdParams -# TypeAdapter for discriminating A2A requests by method A2ARequest = TypeAdapter( Annotated[ - Union[ - SendTaskRequest, - GetTaskRequest, - CancelTaskRequest, - SetTaskPushNotificationRequest, - GetTaskPushNotificationRequest, - TaskResubscriptionRequest, - SendTaskStreamingRequest, - ], + SendTaskRequest + | GetTaskRequest + | CancelTaskRequest + | SetTaskPushNotificationRequest + | GetTaskPushNotificationRequest + | TaskResubscriptionRequest + | SendTaskStreamingRequest, Field(discriminator="method"), ] ) +## Error types -# Agent Card schemas + +class JSONParseError(JSONRPCError): + code: int = -32700 + message: str = "Invalid JSON payload" + data: Any | None = None + + +class InvalidRequestError(JSONRPCError): + code: int = -32600 + message: str = "Request payload validation error" + data: Any | None = None + + +class MethodNotFoundError(JSONRPCError): + code: int = -32601 + message: str = "Method not found" + data: None = None + + +class InvalidParamsError(JSONRPCError): + code: int = -32602 + message: str = "Invalid parameters" + data: Any | None = None + + +class InternalError(JSONRPCError): + code: int = -32603 + message: str = "Internal error" + data: Any | None = None + + +class TaskNotFoundError(JSONRPCError): + code: int = -32001 + message: str = "Task not found" + data: None = None + + +class TaskNotCancelableError(JSONRPCError): + code: int = -32002 + message: str = "Task cannot be canceled" + data: None = None + + +class PushNotificationNotSupportedError(JSONRPCError): + code: int = -32003 + message: str = "Push Notification is not supported" + data: None = None + + +class UnsupportedOperationError(JSONRPCError): + code: int = -32004 + message: str = "This operation is not supported" + data: None = None + + +class ContentTypeNotSupportedError(JSONRPCError): + code: int = -32005 + message: str = "Incompatible content types" + data: None = None class AgentProvider(BaseModel): - """ - Represents the provider of an agent. - """ - organization: str url: str | None = None class AgentCapabilities(BaseModel): - """ - Represents the capabilities of an agent. - """ - streaming: bool = False pushNotifications: bool = False stateTransitionHistory: bool = False class AgentAuthentication(BaseModel): - """ - Represents the authentication requirements for an agent. - """ - - schemes: List[str] + schemes: list[str] credentials: str | None = None class AgentSkill(BaseModel): - """ - Represents a skill of an agent. - """ - id: str name: str description: str | None = None - tags: List[str] | None = None - examples: List[str] | None = None - inputModes: List[str] | None = None - outputModes: List[str] | None = None + tags: list[str] | None = None + examples: list[str] | None = None + inputModes: list[str] | None = None + outputModes: list[str] | None = None class AgentCard(BaseModel): - """ - Represents an agent card in the A2A protocol. - """ - name: str description: str | None = None url: str @@ -459,6 +341,27 @@ class AgentCard(BaseModel): documentationUrl: str | None = None capabilities: AgentCapabilities authentication: AgentAuthentication | None = None - defaultInputModes: List[str] = ["text"] - defaultOutputModes: List[str] = ["text"] - skills: List[AgentSkill] + defaultInputModes: list[str] = ["text"] + defaultOutputModes: list[str] = ["text"] + skills: list[AgentSkill] + + +class A2AClientError(Exception): + pass + + +class A2AClientHTTPError(A2AClientError): + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + super().__init__(f"HTTP Error {status_code}: {message}") + + +class A2AClientJSONError(A2AClientError): + def __init__(self, message: str): + self.message = message + super().__init__(f"JSON Error: {message}") + + +class MissingAPIKeyError(Exception): + """Exception for missing API key.""" diff --git a/src/services/__init__.py b/src/services/__init__.py index 9434fe1e..255943f4 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -1,9 +1 @@ from .agent_runner import run_agent -from .redis_cache_service import RedisCacheService -from .a2a_task_manager_service import A2ATaskManager -from .a2a_server_service import A2AServer -from .a2a_integration_service import ( - AgentRunnerAdapter, - StreamingServiceAdapter, - create_agent_card_from_agent, -) diff --git a/src/services/a2a_agent.py b/src/services/a2a_agent.py index b57eb573..135d73bc 100644 --- a/src/services/a2a_agent.py +++ b/src/services/a2a_agent.py @@ -7,7 +7,7 @@ import json import asyncio import time -from src.schemas.a2a.types import ( +from src.schemas.a2a_types import ( GetTaskRequest, SendTaskRequest, Message, diff --git a/src/services/a2a_integration_service.py b/src/services/a2a_integration_service.py deleted file mode 100644 index 17a386b9..00000000 --- a/src/services/a2a_integration_service.py +++ /dev/null @@ -1,507 +0,0 @@ -""" -A2A Integration Service. - -This service provides adapters to integrate existing services with the A2A protocol. -""" - -import json -import logging -import uuid -from datetime import datetime -from typing import Any, Dict, List, Optional, AsyncIterable - -from src.schemas.a2a import ( - AgentCard, - AgentCapabilities, - AgentProvider, - Artifact, - Message, - TaskArtifactUpdateEvent, - TaskStatus, - TaskStatusUpdateEvent, - TextPart, -) - -logger = logging.getLogger(__name__) - - -class AgentRunnerAdapter: - """ - Adapter for integrating the existing agent runner with the A2A protocol. - """ - - def __init__( - self, - agent_runner_func, - session_service=None, - artifacts_service=None, - memory_service=None, - db=None, - ): - """ - Initialize the adapter. - - Args: - agent_runner_func: The agent runner function (e.g., run_agent) - session_service: Session service for message history - artifacts_service: Artifacts service for artifact history - memory_service: Memory service for agent memory - db: Database session - """ - self.agent_runner_func = agent_runner_func - self.session_service = session_service - self.artifacts_service = artifacts_service - self.memory_service = memory_service - self.db = db - - async def get_supported_modes(self) -> List[str]: - """ - Get the supported output modes for the agent. - - Returns: - List of supported output modes - """ - # Default modes, can be extended based on agent configuration - return ["text", "application/json"] - - async def run_agent( - self, - agent_id: str, - message: str, - session_id: Optional[str] = None, - task_id: Optional[str] = None, - db=None, - ) -> Dict[str, Any]: - """ - Run the agent with the given message. - - Args: - agent_id: ID of the agent to run - message: User message to process - session_id: Optional session ID for conversation context - task_id: Optional task ID for tracking - db: Database session - - Returns: - Dictionary with the agent's response - """ - - try: - # Use the existing agent runner function - # Usar o session_id fornecido, ou gerar um novo - session_id = session_id or str(uuid.uuid4()) - task_id = task_id or str(uuid.uuid4()) - - # Use the provided db or fallback to self.db - db_session = db if db is not None else self.db - - response_text = await self.agent_runner_func( - agent_id=agent_id, - contact_id=task_id, - message=message, - session_service=self.session_service, - artifacts_service=self.artifacts_service, - memory_service=self.memory_service, - db=db_session, - session_id=session_id, - ) - - # Format the response to include both the A2A-compatible message - # and the artifact format to match the Google A2A implementation - # Nota: O formato dos artifacts é simplificado para compatibilidade com Google A2A - message_obj = { - "role": "agent", - "parts": [{"type": "text", "text": response_text}], - } - - # Formato de artefato compatível com Google A2A - artifact_obj = { - "parts": [{"type": "text", "text": response_text}], - "index": 0, - } - - return { - "status": "success", - "content": response_text, - "session_id": session_id, - "task_id": task_id, - "timestamp": datetime.now().isoformat(), - "message": message_obj, - "artifact": artifact_obj, - } - - except Exception as e: - logger.error(f"[AGENT-RUNNER] Error running agent: {e}", exc_info=True) - return { - "status": "error", - "error": str(e), - "session_id": session_id, - "task_id": task_id, - "timestamp": datetime.now().isoformat(), - "message": { - "role": "agent", - "parts": [{"type": "text", "text": f"Error: {str(e)}"}], - }, - } - - async def cancel_task(self, task_id: str) -> bool: - """ - Cancel a running task. - - Args: - task_id: ID of the task to cancel - - Returns: - True if successfully canceled, False otherwise - """ - # Currently, the agent runner doesn't support cancellation - # This is a placeholder for future implementation - logger.warning(f"Task cancellation not implemented for task {task_id}") - return False - - -class StreamingServiceAdapter: - """ - Adapter for integrating the existing streaming service with the A2A protocol. - """ - - def __init__(self, streaming_service): - """ - Initialize the adapter. - - Args: - streaming_service: The streaming service instance - """ - self.streaming_service = streaming_service - - async def stream_agent_response( - self, - agent_id: str, - message: str, - api_key: str, - session_id: Optional[str] = None, - task_id: Optional[str] = None, - db=None, - ) -> AsyncIterable[str]: - """ - Stream the agent's response as A2A events. - - Args: - agent_id: ID of the agent - message: User message to process - api_key: API key for authentication - session_id: Optional session ID for conversation context - task_id: Optional task ID for tracking - db: Database session - - Yields: - A2A event objects as JSON strings for SSE (Server-Sent Events) - """ - task_id = task_id or str(uuid.uuid4()) - logger.info(f"Starting streaming response for task {task_id}") - - # Set working status event - working_status = TaskStatus( - state="working", - timestamp=datetime.now(), - message=Message( - role="agent", parts=[TextPart(text="Processing your request...")] - ), - ) - - status_event = TaskStatusUpdateEvent( - id=task_id, status=working_status, final=False - ) - yield json.dumps(status_event.model_dump()) - - content_buffer = "" - final_sent = False - has_error = False - - # Stream from the existing streaming service - try: - logger.info(f"Setting up streaming for agent {agent_id}, task {task_id}") - # To streaming, we use task_id as contact_id - contact_id = task_id - - last_event_time = datetime.now() - heartbeat_interval = 20 - - async for event in self.streaming_service.send_task_streaming( - agent_id=agent_id, - api_key=api_key, - message=message, - contact_id=contact_id, - session_id=session_id, - db=db, - ): - last_event_time = datetime.now() - - # Process the streaming event format - event_data = event.get("data", "{}") - try: - logger.info(f"Processing event data: {event_data[:100]}...") - data = json.loads(event_data) - - # Extract content - if "delta" in data and data["delta"].get("content"): - content = data["delta"]["content"] - content_buffer += content - logger.info(f"Received content chunk: {content[:50]}...") - - # Create artifact update event - artifact = Artifact( - name="response", - parts=[TextPart(text=content)], - index=0, - append=True, - lastChunk=False, - ) - - artifact_event = TaskArtifactUpdateEvent( - id=task_id, artifact=artifact - ) - yield json.dumps(artifact_event.model_dump()) - - # Check if final event - if data.get("done", False) and not final_sent: - logger.info(f"Received final event for task {task_id}") - # Create completed status event - completed_status = TaskStatus( - state="completed", - timestamp=datetime.now(), - message=Message( - role="agent", - parts=[ - TextPart(text=content_buffer or "Task completed") - ], - ), - ) - - # Final artifact with full content - final_artifact = Artifact( - name="response", - parts=[TextPart(text=content_buffer)], - index=0, - append=False, - lastChunk=True, - ) - - # Send the final artifact - final_artifact_event = TaskArtifactUpdateEvent( - id=task_id, artifact=final_artifact - ) - - yield json.dumps(final_artifact_event.model_dump()) - - # Send the completed status - final_status_event = TaskStatusUpdateEvent( - id=task_id, - status=completed_status, - final=True, - ) - - yield json.dumps(final_status_event.model_dump()) - - final_sent = True - - except json.JSONDecodeError as e: - logger.warning( - f"Received non-JSON event data: {e}. Data: {event_data[:100]}..." - ) - # Handle non-JSON events - simply add to buffer as text - if isinstance(event_data, str): - content_buffer += event_data - - # Create artifact update event - artifact = Artifact( - name="response", - parts=[TextPart(text=event_data)], - index=0, - append=True, - lastChunk=False, - ) - - artifact_event = TaskArtifactUpdateEvent( - id=task_id, artifact=artifact - ) - - yield json.dumps(artifact_event.model_dump()) - elif isinstance(event_data, dict): - # Try to extract text from the dictionary - text_value = str(event_data) - content_buffer += text_value - - artifact = Artifact( - name="response", - parts=[TextPart(text=text_value)], - index=0, - append=True, - lastChunk=False, - ) - - artifact_event = TaskArtifactUpdateEvent( - id=task_id, artifact=artifact - ) - - yield json.dumps(artifact_event.model_dump()) - - # Send heartbeat/keep-alive to keep the SSE connection open - now = datetime.now() - if (now - last_event_time).total_seconds() > heartbeat_interval: - logger.info(f"Sending heartbeat for task {task_id}") - # Sending keep-alive event as a "working" status event - working_heartbeat = TaskStatus( - state="working", - timestamp=now, - message=Message( - role="agent", parts=[TextPart(text="Still processing...")] - ), - ) - heartbeat_event = TaskStatusUpdateEvent( - id=task_id, status=working_heartbeat, final=False - ) - yield json.dumps(heartbeat_event.model_dump()) - last_event_time = now - - # Ensure we send a final event if not already sent - if not final_sent: - logger.info( - f"Stream completed for task {task_id}, sending final status" - ) - # Create completed status event - completed_status = TaskStatus( - state="completed", - timestamp=datetime.now(), - message=Message( - role="agent", - parts=[TextPart(text=content_buffer or "Task completed")], - ), - ) - - # Send the completed status - final_event = TaskStatusUpdateEvent( - id=task_id, status=completed_status, final=True - ) - yield json.dumps(final_event.model_dump()) - - except Exception as e: - has_error = True - logger.error(f"Error in streaming for task {task_id}: {e}", exc_info=True) - - # Create failed status event - failed_status = TaskStatus( - state="failed", - timestamp=datetime.now(), - message=Message( - role="agent", - parts=[ - TextPart( - text=f"Error during streaming: {str(e)}. Partial response: {content_buffer[:200] if content_buffer else 'No content received'}" - ) - ], - ), - ) - - error_event = TaskStatusUpdateEvent( - id=task_id, status=failed_status, final=True - ) - yield json.dumps(error_event.model_dump()) - - finally: - # Ensure we send a final event to properly close the connection - if not final_sent and not has_error: - logger.info(f"Stream finalizing for task {task_id} via finally block") - try: - # Create completed status event - completed_status = TaskStatus( - state="completed", - timestamp=datetime.now(), - message=Message( - role="agent", - parts=[ - TextPart( - text=content_buffer or "Task completed (forced end)" - ) - ], - ), - ) - - # Send the completed status - final_event = TaskStatusUpdateEvent( - id=task_id, status=completed_status, final=True - ) - yield json.dumps(final_event.model_dump()) - except Exception as final_error: - logger.error( - f"Error sending final event in finally block: {final_error}" - ) - - logger.info(f"Streaming completed for task {task_id}") - - -def create_agent_card_from_agent(agent, db) -> AgentCard: - """ - Create an A2A agent card from an agent model. - - Args: - agent: The agent model from the database - db: Database session - - Returns: - A2A AgentCard object - """ - import os - from src.api.agent_routes import format_agent_tools - import asyncio - - # Extract agent configuration - agent_config = agent.config - has_streaming = True # Assuming streaming is always supported - has_push = True # Assuming push notifications are supported - - # Format tools as skills - try: - # We use a different approach to handle the asynchronous function - mcp_servers = agent_config.get("mcp_servers", []) - - # We create a new thread to execute the asynchronous function - import concurrent.futures - - def run_async(coro): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete(coro) - loop.close() - return result - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_async, format_agent_tools(mcp_servers, db)) - skills = future.result() - except Exception as e: - logger.error(f"Error formatting agent tools: {e}") - skills = [] - - # Create agent card - return AgentCard( - name=agent.name, - description=agent.description, - url=f"{os.getenv('API_URL', '')}/api/v1/a2a/{agent.id}", - provider=AgentProvider( - organization=os.getenv("ORGANIZATION_NAME", ""), - url=os.getenv("ORGANIZATION_URL", ""), - ), - version=os.getenv("API_VERSION", "1.0.0"), - capabilities=AgentCapabilities( - streaming=has_streaming, - pushNotifications=has_push, - stateTransitionHistory=True, - ), - authentication={ - "schemes": ["apiKey"], - "credentials": "x-api-key", - }, - defaultInputModes=["text", "application/json"], - defaultOutputModes=["text", "application/json"], - skills=skills, - ) diff --git a/src/services/a2a_server_service.py b/src/services/a2a_server_service.py deleted file mode 100644 index bddf1ac8..00000000 --- a/src/services/a2a_server_service.py +++ /dev/null @@ -1,708 +0,0 @@ -""" -Server A2A and task manager for the A2A protocol. - -This module implements a JSON-RPC compatible server for the A2A protocol, -that manages agent tasks, streaming events and push notifications. -""" - -import json -import logging -from datetime import datetime -from typing import ( - Any, - Dict, - List, - Optional, - AsyncGenerator, - Union, - AsyncIterable, -) -from fastapi import Request -from fastapi.responses import JSONResponse, StreamingResponse, Response -from sqlalchemy.orm import Session - -from src.schemas.a2a.types import A2ARequest -from src.services.a2a_integration_service import ( - AgentRunnerAdapter, - StreamingServiceAdapter, -) -from src.services.session_service import get_session_events -from src.services.redis_cache_service import RedisCacheService -from src.schemas.a2a.types import ( - SendTaskRequest, - SendTaskStreamingRequest, - GetTaskRequest, - CancelTaskRequest, - SetTaskPushNotificationRequest, - GetTaskPushNotificationRequest, - TaskResubscriptionRequest, -) - -logger = logging.getLogger(__name__) - - -class A2ATaskManager: - """ - Task manager for the A2A protocol. - - This class manages the lifecycle of tasks, including: - - Task execution - - Streaming of events - - Push notifications - - Status querying - - Cancellation - """ - - def __init__( - self, - redis_cache: RedisCacheService, - agent_runner: AgentRunnerAdapter, - streaming_service: StreamingServiceAdapter, - push_notification_service: Any = None, - ): - """ - Initialize the task manager. - - Args: - redis_cache: Cache service for storing task data - agent_runner: Adapter for agent execution - streaming_service: Adapter for event streaming - push_notification_service: Service for sending push notifications - """ - self.cache = redis_cache - self.agent_runner = agent_runner - self.streaming_service = streaming_service - self.push_notification_service = push_notification_service - self._running_tasks = {} - - async def on_send_task( - self, - task_id: str, - agent_id: str, - message: Dict[str, Any], - session_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - input_mode: str = "text", - output_modes: List[str] = ["text"], - db: Optional[Session] = None, - ) -> Dict[str, Any]: - """ - Process a request to send a task. - - Args: - task_id: Task ID - agent_id: Agent ID - message: User message - session_id: Session ID (optional) - metadata: Additional metadata (optional) - input_mode: Input mode (text, JSON, etc.) - output_modes: Supported output modes - db: Database session - - Returns: - Response with task result - """ - if not session_id: - session_id = f"{task_id}_{agent_id}" - - if not metadata: - metadata = {} - - # Update status to "submitted" - task_data = { - "id": task_id, - "sessionId": session_id, - "status": { - "state": "submitted", - "timestamp": datetime.now().isoformat(), - "message": None, - "error": None, - }, - "artifacts": [], - "history": [], - "metadata": metadata, - } - - # Store initial task data - await self.cache.set(f"task:{task_id}", task_data) - - # Check for push notification configurations - push_config = await self.cache.get(f"task:{task_id}:push") - if push_config and self.push_notification_service: - # Send initial notification - await self.push_notification_service.send_notification( - url=push_config["url"], - task_id=task_id, - state="submitted", - headers=push_config.get("headers", {}), - ) - - try: - # Update status to "running" - task_data["status"].update( - {"state": "running", "timestamp": datetime.now().isoformat()} - ) - await self.cache.set(f"task:{task_id}", task_data) - - # Notify "running" state - if push_config and self.push_notification_service: - await self.push_notification_service.send_notification( - url=push_config["url"], - task_id=task_id, - state="running", - headers=push_config.get("headers", {}), - ) - - # Extract user message - user_message = None - try: - user_message = message["parts"][0]["text"] - except (KeyError, IndexError): - user_message = "" - - # Execute the agent - response = await self.agent_runner.run_agent( - agent_id=agent_id, - task_id=task_id, - message=user_message, - session_id=session_id, - db=db, - ) - - # Check if the response is a dictionary (error) or a string (success) - if isinstance(response, dict) and response.get("status") == "error": - # Error response - final_response = f"Error: {response.get('error', 'Unknown error')}" - - # Update status to "failed" - task_data["status"].update( - { - "state": "failed", - "timestamp": datetime.now().isoformat(), - "error": { - "code": "AGENT_EXECUTION_ERROR", - "message": response.get("error", "Unknown error"), - }, - "message": { - "role": "system", - "parts": [{"type": "text", "text": final_response}], - }, - } - ) - - # Notify "failed" state - if push_config and self.push_notification_service: - await self.push_notification_service.send_notification( - url=push_config["url"], - task_id=task_id, - state="failed", - message={ - "role": "system", - "parts": [{"type": "text", "text": final_response}], - }, - headers=push_config.get("headers", {}), - ) - else: - # Success response - final_response = ( - response.get("content") if isinstance(response, dict) else response - ) - - # Update status to "completed" - task_data["status"].update( - { - "state": "completed", - "timestamp": datetime.now().isoformat(), - "message": { - "role": "agent", - "parts": [{"type": "text", "text": final_response}], - }, - } - ) - - # Add artifacts - if final_response: - task_data["artifacts"].append( - { - "type": "text", - "content": final_response, - "metadata": { - "generated_at": datetime.now().isoformat(), - "content_type": "text/plain", - }, - } - ) - - # Add history of messages - history_length = metadata.get("historyLength", 50) - try: - history_messages = get_session_events( - self.agent_runner.session_service, session_id - ) - history_messages = history_messages[-history_length:] - - formatted_history = [] - for event in history_messages: - if event.content and event.content.parts: - role = ( - "agent" - if event.content.role == "model" - else event.content.role - ) - formatted_history.append( - { - "role": role, - "parts": [ - {"type": "text", "text": part.text} - for part in event.content.parts - if part.text - ], - } - ) - - task_data["history"] = formatted_history - except Exception as e: - logger.error(f"Error processing history: {str(e)}") - - # Notify "completed" state - if push_config and self.push_notification_service: - await self.push_notification_service.send_notification( - url=push_config["url"], - task_id=task_id, - state="completed", - message={ - "role": "agent", - "parts": [{"type": "text", "text": final_response}], - }, - headers=push_config.get("headers", {}), - ) - - except Exception as e: - logger.error(f"Error executing task {task_id}: {str(e)}") - - # Update status to "failed" - task_data["status"].update( - { - "state": "failed", - "timestamp": datetime.now().isoformat(), - "error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)}, - } - ) - - # Notify "failed" state - if push_config and self.push_notification_service: - await self.push_notification_service.send_notification( - url=push_config["url"], - task_id=task_id, - state="failed", - message={ - "role": "system", - "parts": [{"type": "text", "text": str(e)}], - }, - headers=push_config.get("headers", {}), - ) - - # Store final result - await self.cache.set(f"task:{task_id}", task_data) - return task_data - - async def on_send_task_subscribe( - self, - task_id: str, - agent_id: str, - message: Dict[str, Any], - session_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - input_mode: str = "text", - output_modes: List[str] = ["text"], - db: Optional[Session] = None, - ) -> AsyncGenerator[str, None]: - """ - Process a request to send a task with streaming. - - Args: - task_id: Task ID - agent_id: Agent ID - message: User message - session_id: Session ID (optional) - metadata: Additional metadata (optional) - input_mode: Input mode (text, JSON, etc.) - output_modes: Supported output modes - db: Database session - - Yields: - Streaming events in SSE (Server-Sent Events) format - """ - if not session_id: - session_id = f"{task_id}_{agent_id}" - - if not metadata: - metadata = {} - - # Extract user message - user_message = "" - try: - user_message = message["parts"][0]["text"] - except (KeyError, IndexError): - pass - - # Generate streaming events - async for event in self.streaming_service.stream_response( - agent_id=agent_id, - task_id=task_id, - message=user_message, - session_id=session_id, - db=db, - ): - yield event - - async def on_get_task(self, task_id: str) -> Dict[str, Any]: - """ - Query the status of a task by ID. - - Args: - task_id: Task ID - - Returns: - Current task status - - Raises: - Exception: If the task is not found - """ - task_data = await self.cache.get(f"task:{task_id}") - if not task_data: - raise Exception(f"Task {task_id} not found") - return task_data - - async def on_cancel_task(self, task_id: str) -> Dict[str, Any]: - """ - Cancel a running task. - - Args: - task_id: Task ID to be cancelled - - Returns: - Task status after cancellation - - Raises: - Exception: If the task is not found or cannot be cancelled - """ - task_data = await self.cache.get(f"task:{task_id}") - if not task_data: - raise Exception(f"Task {task_id} not found") - - # Check if the task is in a state that can be cancelled - current_state = task_data["status"]["state"] - if current_state not in ["submitted", "running"]: - raise Exception(f"Cannot cancel task in {current_state} state") - - # Cancel the task in the runner if it is running - running_task = self._running_tasks.get(task_id) - if running_task: - # Try to cancel the running task - if hasattr(running_task, "cancel"): - running_task.cancel() - - # Update status to "cancelled" - task_data["status"].update( - { - "state": "cancelled", - "timestamp": datetime.now().isoformat(), - "message": { - "role": "system", - "parts": [{"type": "text", "text": "Task cancelled by user"}], - }, - } - ) - - # Update cache - await self.cache.set(f"task:{task_id}", task_data) - - # Send push notification if configured - push_config = await self.cache.get(f"task:{task_id}:push") - if push_config and self.push_notification_service: - await self.push_notification_service.send_notification( - url=push_config["url"], - task_id=task_id, - state="cancelled", - message={ - "role": "system", - "parts": [{"type": "text", "text": "Task cancelled by user"}], - }, - headers=push_config.get("headers", {}), - ) - - return task_data - - async def on_set_task_push_notification( - self, task_id: str, notification_config: Dict[str, Any] - ) -> Dict[str, Any]: - """ - Configure push notifications for a task. - - Args: - task_id: Task ID - notification_config: Notification configuration (URL and headers) - - Returns: - Updated configuration - """ - # Validate configuration - url = notification_config.get("url") - if not url: - raise ValueError("Push notification URL is required") - - headers = notification_config.get("headers", {}) - - # Store configuration - config = {"url": url, "headers": headers} - await self.cache.set(f"task:{task_id}:push", config) - - return config - - async def on_get_task_push_notification(self, task_id: str) -> Dict[str, Any]: - """ - Get the push notification configuration for a task. - - Args: - task_id: Task ID - - Returns: - Push notification configuration - - Raises: - Exception: If there is no configuration for the task - """ - config = await self.cache.get(f"task:{task_id}:push") - if not config: - raise Exception(f"No push notification configuration for task {task_id}") - return config - - -class A2AServer: - """ - A2A server compatible with JSON-RPC 2.0. - - This class processes JSON-RPC requests and forwards them to - the appropriate handlers in the A2ATaskManager. - """ - - def __init__(self, task_manager: A2ATaskManager, agent_card=None): - """ - Initialize the A2A server. - - Args: - task_manager: Task manager - agent_card: Agent card information - """ - self.task_manager = task_manager - self.agent_card = agent_card - - async def process_request( - self, - request: Request, - agent_id: Optional[str] = None, - db: Optional[Session] = None, - ) -> Union[Response, JSONResponse, StreamingResponse]: - """ - Process a JSON-RPC request. - - Args: - request: HTTP request - agent_id: Optional agent ID to inject into the request - db: Database session - - Returns: - Appropriate response (JSON or Streaming) - """ - try: - # Try to parse the JSON payload - try: - logger.info("Starting JSON-RPC request processing") - body = await request.json() - logger.info(f"Received JSON data: {json.dumps(body)}") - method = body.get("method", "unknown") - - # Validate the request using the A2A validator - json_rpc_request = A2ARequest.validate_python(body) - - original_db = self.task_manager.db - try: - # Set the db temporarily - if db is not None: - self.task_manager.db = db - - # Process the request - if isinstance(json_rpc_request, SendTaskRequest): - json_rpc_request.params.agentId = agent_id - result = await self.task_manager.on_send_task(json_rpc_request) - elif isinstance(json_rpc_request, SendTaskStreamingRequest): - json_rpc_request.params.agentId = agent_id - result = await self.task_manager.on_send_task_subscribe( - json_rpc_request - ) - elif isinstance(json_rpc_request, GetTaskRequest): - result = await self.task_manager.on_get_task(json_rpc_request) - elif isinstance(json_rpc_request, CancelTaskRequest): - result = await self.task_manager.on_cancel_task( - json_rpc_request - ) - elif isinstance(json_rpc_request, SetTaskPushNotificationRequest): - result = await self.task_manager.on_set_task_push_notification( - json_rpc_request - ) - elif isinstance(json_rpc_request, GetTaskPushNotificationRequest): - result = await self.task_manager.on_get_task_push_notification( - json_rpc_request - ) - elif isinstance(json_rpc_request, TaskResubscriptionRequest): - result = await self.task_manager.on_resubscribe_to_task( - json_rpc_request - ) - else: - logger.warning( - f"[SERVER] Request type not supported: {type(json_rpc_request)}" - ) - return JSONResponse( - status_code=400, - content={ - "jsonrpc": "2.0", - "id": body.get("id"), - "error": { - "code": -32601, - "message": "Method not found", - "data": { - "detail": f"Method not supported: {method}" - }, - }, - }, - ) - finally: - # Restore the original db - self.task_manager.db = original_db - - # Create appropriate response - return self._create_response(result) - - except json.JSONDecodeError as e: - # Error parsing JSON - logger.error(f"Error parsing JSON request: {str(e)}") - return JSONResponse( - status_code=400, - content={ - "jsonrpc": "2.0", - "id": None, - "error": { - "code": -32700, - "message": "Parse error", - "data": {"detail": str(e)}, - }, - }, - ) - except Exception as e: - # Other validation errors - logger.error(f"Error validating request: {str(e)}") - return JSONResponse( - status_code=400, - content={ - "jsonrpc": "2.0", - "id": body.get("id") if "body" in locals() else None, - "error": { - "code": -32600, - "message": "Invalid Request", - "data": {"detail": str(e)}, - }, - }, - ) - - except Exception as e: - logger.error(f"Error processing JSON-RPC request: {str(e)}", exc_info=True) - return JSONResponse( - status_code=500, - content={ - "jsonrpc": "2.0", - "id": None, - "error": { - "code": -32603, - "message": "Internal error", - "data": {"detail": str(e)}, - }, - }, - ) - - def _create_response(self, result: Any) -> Union[JSONResponse, StreamingResponse]: - """ - Create appropriate response based on result type. - - Args: - result: Result from task manager - - Returns: - JSON or streaming response - """ - if isinstance(result, AsyncIterable): - # Result in streaming (SSE) - async def event_generator(): - async for item in result: - if hasattr(item, "model_dump_json"): - yield {"data": item.model_dump_json(exclude_none=True)} - else: - yield {"data": json.dumps(item)} - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - elif hasattr(result, "model_dump"): - # Result is a Pydantic object - return JSONResponse(result.model_dump(exclude_none=True)) - - else: - # Result is a dictionary or other simple type - return JSONResponse(result) - - async def get_agent_card( - self, request: Request, db: Optional[Session] = None - ) -> JSONResponse: - """ - Get the agent card. - - Args: - request: HTTP request - db: Database session - - Returns: - Agent card as JSON - """ - if not self.agent_card: - logger.error("Agent card not configured") - return JSONResponse( - status_code=404, content={"error": "Agent card not configured"} - ) - - # If there is db, set it temporarily in the task_manager - if db and hasattr(self.task_manager, "db"): - original_db = self.task_manager.db - try: - self.task_manager.db = db - - # If it's a Pydantic object, convert to dictionary - if hasattr(self.agent_card, "model_dump"): - return JSONResponse(self.agent_card.model_dump(exclude_none=True)) - else: - return JSONResponse(self.agent_card) - finally: - # Restore the original db - self.task_manager.db = original_db - else: - # If it's a Pydantic object, convert to dictionary - if hasattr(self.agent_card, "model_dump"): - return JSONResponse(self.agent_card.model_dump(exclude_none=True)) - else: - return JSONResponse(self.agent_card) diff --git a/src/services/a2a_task_manager.py b/src/services/a2a_task_manager.py new file mode 100644 index 00000000..08280270 --- /dev/null +++ b/src/services/a2a_task_manager.py @@ -0,0 +1,644 @@ +import json +import logging +import asyncio +from collections.abc import AsyncIterable +from typing import Any, Dict, Optional, Union, List +from uuid import UUID + +from sqlalchemy.orm import Session + +from src.config.settings import settings +from src.services.agent_service import ( + get_agent, + create_agent, + update_agent, + delete_agent, + get_agents_by_client, +) +from src.services.mcp_server_service import get_mcp_server +from src.services.session_service import ( + get_sessions_by_client, + get_sessions_by_agent, + get_session_by_id, + delete_session, + get_session_events, +) + +from src.services.agent_runner import run_agent, run_agent_stream +from src.services.service_providers import ( + session_service, + artifacts_service, + memory_service, +) +from src.models.models import Agent +from src.schemas.a2a_types import ( + A2ARequest, + GetTaskRequest, + SendTaskRequest, + SendTaskResponse, + SendTaskStreamingRequest, + SendTaskStreamingResponse, + CancelTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + TaskResubscriptionRequest, + JSONRPCResponse, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + Task, + TaskSendParams, + InternalError, + Message, + Artifact, + TaskStatus, + TaskState, + AgentCard, + AgentCapabilities, + AgentSkill, +) + +logger = logging.getLogger(__name__) + + +class A2ATaskManager: + """Task manager for the A2A protocol.""" + + def __init__(self, db: Session): + self.db = db + self.tasks: Dict[str, Task] = {} + self.lock = asyncio.Lock() + + async def upsert_task(self, task_params: TaskSendParams) -> Task: + """Creates or updates a task in the store.""" + async with self.lock: + task = self.tasks.get(task_params.id) + if task is None: + # Create new task with initial history + task = Task( + id=task_params.id, + sessionId=task_params.sessionId, + status=TaskStatus(state=TaskState.SUBMITTED), + history=[task_params.message], + artifacts=[], + ) + self.tasks[task_params.id] = task + else: + # Add message to existing history + if task.history is None: + task.history = [] + task.history.append(task_params.message) + + return task + + async def on_get_task(self, request: GetTaskRequest) -> JSONRPCResponse: + """Handles requests to get task details.""" + try: + task_id = request.params.id + history_length = request.params.historyLength + + async with self.lock: + if task_id not in self.tasks: + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Task {task_id} not found"), + ) + + # Get the task and limit the history as requested + task = self.tasks[task_id] + task_result = self.append_task_history(task, history_length) + + return SendTaskResponse(id=request.id, result=task_result) + except Exception as e: + logger.error(f"Error getting task: {e}") + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Error getting task: {str(e)}"), + ) + + async def on_send_task( + self, request: SendTaskRequest, agent_id: UUID + ) -> JSONRPCResponse: + """Handles requests to send a task for processing.""" + try: + agent = get_agent(self.db, agent_id) + if not agent: + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Agent {agent_id} not found"), + ) + + await self.upsert_task(request.params) + return await self._process_task(request, agent) + except Exception as e: + logger.error(f"Error sending task: {e}") + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Error sending task: {str(e)}"), + ) + + async def on_send_task_subscribe( + self, request: SendTaskStreamingRequest, agent_id: UUID + ) -> AsyncIterable[SendTaskStreamingResponse]: + """Handles requests to send a task and subscribe to updates.""" + try: + agent = get_agent(self.db, agent_id) + if not agent: + yield JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Agent {agent_id} not found"), + ) + return + + await self.upsert_task(request.params) + async for response in self._stream_task_process(request, agent): + yield response + except Exception as e: + logger.error(f"Error processing streaming task: {e}") + yield JSONRPCResponse( + id=request.id, + error=InternalError( + message=f"Error processing streaming task: {str(e)}" + ), + ) + + async def on_cancel_task(self, request: CancelTaskRequest) -> JSONRPCResponse: + """Handles requests to cancel a task.""" + try: + task_id = request.params.id + async with self.lock: + if task_id in self.tasks: + task = self.tasks[task_id] + task.status = TaskStatus(state=TaskState.CANCELED) + return JSONRPCResponse(id=request.id, result=True) + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Task {task_id} not found"), + ) + except Exception as e: + logger.error(f"Error canceling task: {e}") + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Error canceling task: {str(e)}"), + ) + + async def on_set_task_push_notification( + self, request: SetTaskPushNotificationRequest + ) -> JSONRPCResponse: + """Handles requests to configure push notifications for a task.""" + return JSONRPCResponse(id=request.id, result=True) + + async def on_get_task_push_notification( + self, request: GetTaskPushNotificationRequest + ) -> JSONRPCResponse: + """Handles requests to get push notification settings for a task.""" + return JSONRPCResponse(id=request.id, result={}) + + async def on_resubscribe_to_task( + self, request: TaskResubscriptionRequest + ) -> AsyncIterable[SendTaskStreamingResponse]: + """Handles requests to resubscribe to a task.""" + task_id = request.params.id + try: + async with self.lock: + if task_id not in self.tasks: + yield SendTaskStreamingResponse( + id=request.id, + error=InternalError(message=f"Task {task_id} not found"), + ) + return + + task = self.tasks[task_id] + + yield SendTaskStreamingResponse( + id=request.id, + result=TaskStatusUpdateEvent( + id=task_id, + status=task.status, + final=False, + ), + ) + + if task.artifacts: + for artifact in task.artifacts: + yield SendTaskStreamingResponse( + id=request.id, + result=TaskArtifactUpdateEvent(id=task_id, artifact=artifact), + ) + + yield SendTaskStreamingResponse( + id=request.id, + result=TaskStatusUpdateEvent( + id=task_id, + status=TaskStatus(state=task.status.state), + final=True, + ), + ) + + except Exception as e: + logger.error(f"Error resubscribing to task: {e}") + yield SendTaskStreamingResponse( + id=request.id, + error=InternalError(message=f"Error resubscribing to task: {str(e)}"), + ) + + async def _process_task( + self, request: SendTaskRequest, agent: Agent + ) -> JSONRPCResponse: + """Processes a task using the specified agent.""" + task_params = request.params + query = self._extract_user_query(task_params) + + try: + # Process the query with the agent + result = await self._run_agent(agent, query, task_params.sessionId) + + # Create the response part + text_part = {"type": "text", "text": result} + parts = [text_part] + agent_message = Message(role="agent", parts=parts) + + # Determine the task state + task_state = ( + TaskState.INPUT_REQUIRED + if "MISSING_INFO:" in result + else TaskState.COMPLETED + ) + + # Update the task in the store + task = await self.update_store( + task_params.id, + TaskStatus(state=task_state, message=agent_message), + [Artifact(parts=parts, index=0)], + ) + + return SendTaskResponse(id=request.id, result=task) + except Exception as e: + logger.error(f"Error processing task: {e}") + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Error processing task: {str(e)}"), + ) + + async def _stream_task_process( + self, request: SendTaskStreamingRequest, agent: Agent + ) -> AsyncIterable[SendTaskStreamingResponse]: + """Processes a task in streaming mode using the specified agent.""" + task_params = request.params + query = self._extract_user_query(task_params) + + try: + # Send initial processing status + processing_text_part = { + "type": "text", + "text": "Processing your request...", + } + processing_message = Message(role="agent", parts=[processing_text_part]) + + # Update the task with the processing message and inform the WORKING state + await self.update_store( + task_params.id, + TaskStatus(state=TaskState.WORKING, message=processing_message), + ) + + yield SendTaskStreamingResponse( + id=request.id, + result=TaskStatusUpdateEvent( + id=task_params.id, + status=TaskStatus( + state=TaskState.WORKING, + message=processing_message, + ), + final=False, + ), + ) + + # Collect the chunks of the agent's response + contact_id = task_params.sessionId + full_response = "" + + # We use the same streaming function used in the WebSocket + async for chunk in run_agent_stream( + agent_id=str(agent.id), + contact_id=contact_id, + message=query, + session_service=session_service, + artifacts_service=artifacts_service, + memory_service=memory_service, + db=self.db, + ): + # Send incremental progress updates + update_text_part = {"type": "text", "text": chunk} + update_message = Message(role="agent", parts=[update_text_part]) + + # Update the task with each intermediate message + await self.update_store( + task_params.id, + TaskStatus(state=TaskState.WORKING, message=update_message), + ) + + yield SendTaskStreamingResponse( + id=request.id, + result=TaskStatusUpdateEvent( + id=task_params.id, + status=TaskStatus( + state=TaskState.WORKING, + message=update_message, + ), + final=False, + ), + ) + full_response += chunk + + # Determine the task state + task_state = ( + TaskState.INPUT_REQUIRED + if "MISSING_INFO:" in full_response + else TaskState.COMPLETED + ) + + # Create the final response part + final_text_part = {"type": "text", "text": full_response} + parts = [final_text_part] + final_message = Message(role="agent", parts=parts) + + # Create the final artifact from the final response + final_artifact = Artifact(parts=parts, index=0) + + # Update the task in the store with the final response + task = await self.update_store( + task_params.id, + TaskStatus(state=task_state, message=final_message), + [final_artifact], + ) + + # Send the final artifact + yield SendTaskStreamingResponse( + id=request.id, + result=TaskArtifactUpdateEvent( + id=task_params.id, artifact=final_artifact + ), + ) + + # Send the final status + yield SendTaskStreamingResponse( + id=request.id, + result=TaskStatusUpdateEvent( + id=task_params.id, + status=TaskStatus(state=task_state), + final=True, + ), + ) + except Exception as e: + logger.error(f"Error streaming task process: {e}") + yield JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Error streaming task process: {str(e)}"), + ) + + async def update_store( + self, + task_id: str, + status: TaskStatus, + artifacts: Optional[list[Artifact]] = None, + ) -> Task: + """Updates the status and artifacts of a task.""" + async with self.lock: + if task_id not in self.tasks: + raise ValueError(f"Task {task_id} not found") + + task = self.tasks[task_id] + task.status = status + + # Add message to history if it exists + if status.message is not None: + if task.history is None: + task.history = [] + task.history.append(status.message) + + if artifacts is not None: + if task.artifacts is None: + task.artifacts = [] + task.artifacts.extend(artifacts) + + return task + + def _extract_user_query(self, task_params: TaskSendParams) -> str: + """Extracts the user query from the task parameters.""" + if not task_params.message or not task_params.message.parts: + raise ValueError("Message or parts are missing in task parameters") + + part = task_params.message.parts[0] + if part.type != "text": + raise ValueError("Only text parts are supported") + + return part.text + + async def _run_agent(self, agent: Agent, query: str, session_id: str) -> str: + """Executes the agent to process the user query.""" + try: + # We use the session_id as contact_id to maintain the conversation continuity + contact_id = session_id + + # We call the same function used in the chat API + final_response = await run_agent( + agent_id=str(agent.id), + contact_id=contact_id, + message=query, + session_service=session_service, + artifacts_service=artifacts_service, + memory_service=memory_service, + db=self.db, + ) + + return final_response + except Exception as e: + logger.error(f"Error running agent: {e}") + raise ValueError(f"Error running agent: {str(e)}") + + def append_task_history(self, task: Task, history_length: int | None) -> Task: + """Returns a copy of the task with the history limited to the specified size.""" + # Create a copy of the task + new_task = task.model_copy() + + # Limit the history if requested + if history_length is not None: + if history_length > 0: + new_task.history = ( + new_task.history[-history_length:] if new_task.history else [] + ) + else: + new_task.history = [] + + return new_task + + +class A2AService: + """Service to manage A2A requests and agent cards.""" + + def __init__(self, db: Session, task_manager: A2ATaskManager): + self.db = db + self.task_manager = task_manager + + async def process_request( + self, agent_id: UUID, request_body: dict + ) -> JSONRPCResponse: + """Processes an A2A request.""" + try: + request = A2ARequest.validate_python(request_body) + + if isinstance(request, GetTaskRequest): + return await self.task_manager.on_get_task(request) + elif isinstance(request, SendTaskRequest): + return await self.task_manager.on_send_task(request, agent_id) + elif isinstance(request, SendTaskStreamingRequest): + return self.task_manager.on_send_task_subscribe(request, agent_id) + elif isinstance(request, CancelTaskRequest): + return await self.task_manager.on_cancel_task(request) + elif isinstance(request, SetTaskPushNotificationRequest): + return await self.task_manager.on_set_task_push_notification(request) + elif isinstance(request, GetTaskPushNotificationRequest): + return await self.task_manager.on_get_task_push_notification(request) + elif isinstance(request, TaskResubscriptionRequest): + return self.task_manager.on_resubscribe_to_task(request) + else: + logger.warning(f"Unexpected request type: {type(request)}") + return JSONRPCResponse( + id=getattr(request, "id", None), + error=InternalError( + message=f"Unexpected request type: {type(request)}" + ), + ) + except Exception as e: + logger.error(f"Error processing A2A request: {e}") + return JSONRPCResponse( + id=None, + error=InternalError(message=f"Error processing A2A request: {str(e)}"), + ) + + def get_agent_card(self, agent_id: UUID) -> AgentCard: + """Gets the agent card for the specified agent.""" + agent = get_agent(self.db, agent_id) + if not agent: + raise ValueError(f"Agent {agent_id} not found") + + # Build the agent card based on the agent's information + capabilities = AgentCapabilities(streaming=True) + + # List to store all skills + skills = [] + + # Check if the agent has MCP servers configured + if ( + agent.config + and "mcp_servers" in agent.config + and agent.config["mcp_servers"] + ): + logger.info( + f"Agent {agent_id} has {len(agent.config['mcp_servers'])} MCP servers configured" + ) + + for mcp_config in agent.config["mcp_servers"]: + # Get the MCP server + mcp_server_id = mcp_config.get("id") + if not mcp_server_id: + logger.warning("MCP server configuration missing ID") + continue + + logger.info(f"Processing MCP server: {mcp_server_id}") + mcp_server = get_mcp_server(self.db, mcp_server_id) + if not mcp_server: + logger.warning(f"MCP server {mcp_server_id} not found") + continue + + # Get the available tools in the MCP server + mcp_tools = mcp_config.get("tools", []) + logger.info(f"MCP server {mcp_server.name} has tools: {mcp_tools}") + + # Add server tools as skills + for tool_name in mcp_tools: + logger.info(f"Processing tool: {tool_name}") + + # Buscar informações da ferramenta pelo ID + tool_info = None + if hasattr(mcp_server, "tools") and isinstance( + mcp_server.tools, list + ): + for tool in mcp_server.tools: + if isinstance(tool, dict) and tool.get("id") == tool_name: + tool_info = tool + logger.info( + f"Found tool info for {tool_name}: {tool_info}" + ) + break + + if tool_info: + # Use the information from the tool + skill = AgentSkill( + id=tool_info.get("id", f"{agent.id}_{tool_name}"), + name=tool_info.get("name", tool_name), + description=tool_info.get( + "description", f"Tool: {tool_name}" + ), + tags=tool_info.get( + "tags", [mcp_server.name, "tool", tool_name] + ), + examples=tool_info.get("examples", []), + inputModes=tool_info.get("inputModes", ["text"]), + outputModes=tool_info.get("outputModes", ["text"]), + ) + else: + # Default skill if tool info not found + skill = AgentSkill( + id=f"{agent.id}_{tool_name}", + name=tool_name, + description=f"Tool: {tool_name}", + tags=[mcp_server.name, "tool", tool_name], + examples=[], + inputModes=["text"], + outputModes=["text"], + ) + + skills.append(skill) + logger.info(f"Added skill for tool: {tool_name}") + + # Check custom tools + if ( + agent.config + and "custom_tools" in agent.config + and agent.config["custom_tools"] + ): + custom_tools = agent.config["custom_tools"] + + # Check HTTP tools + if "http_tools" in custom_tools and custom_tools["http_tools"]: + logger.info(f"Agent has {len(custom_tools['http_tools'])} HTTP tools") + for http_tool in custom_tools["http_tools"]: + skill = AgentSkill( + id=f"{agent.id}_http_{http_tool['name']}", + name=http_tool["name"], + description=http_tool.get( + "description", f"HTTP Tool: {http_tool['name']}" + ), + tags=http_tool.get( + "tags", ["http", "custom_tool", http_tool["method"]] + ), + examples=http_tool.get("examples", []), + inputModes=http_tool.get("inputModes", ["text"]), + outputModes=http_tool.get("outputModes", ["text"]), + ) + skills.append(skill) + logger.info(f"Added skill for HTTP tool: {http_tool['name']}") + + card = AgentCard( + name=agent.name, + description=agent.description or "", + url=f"{settings.API_URL}/api/v1/a2a/{agent_id}", + version="1.0.0", + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=capabilities, + skills=skills, + ) + + logger.info(f"Generated agent card with {len(skills)} skills") + return card diff --git a/src/services/a2a_task_manager_service.py b/src/services/a2a_task_manager_service.py deleted file mode 100644 index 1857949c..00000000 --- a/src/services/a2a_task_manager_service.py +++ /dev/null @@ -1,949 +0,0 @@ -""" -A2A Task Manager Service. - -This service implements task management for the A2A protocol, handling task lifecycle -including execution, streaming, push notifications, status queries, and cancellation. -""" - -import asyncio -import logging -from datetime import datetime -from typing import Any, Dict, Union, AsyncIterable -import uuid - -from src.schemas.a2a.exceptions import ( - TaskNotFoundError, - TaskNotCancelableError, - PushNotificationNotSupportedError, - InternalError, - ContentTypeNotSupportedError, -) - -from src.schemas.a2a.types import ( - JSONRPCResponse, - GetTaskRequest, - SendTaskRequest, - CancelTaskRequest, - SetTaskPushNotificationRequest, - GetTaskPushNotificationRequest, - GetTaskResponse, - CancelTaskResponse, - SendTaskResponse, - SetTaskPushNotificationResponse, - GetTaskPushNotificationResponse, - TaskSendParams, - TaskStatus, - TaskState, - TaskResubscriptionRequest, - SendTaskStreamingRequest, - SendTaskStreamingResponse, - Artifact, - PushNotificationConfig, - TaskStatusUpdateEvent, - TaskArtifactUpdateEvent, - JSONRPCError, - TaskPushNotificationConfig, - Message, - TextPart, - Task, -) -from src.services.redis_cache_service import RedisCacheService -from src.utils.a2a_utils import ( - are_modalities_compatible, - new_incompatible_types_error, -) - -logger = logging.getLogger(__name__) - - -class A2ATaskManager: - """ - A2A Task Manager implementation. - - This class manages the lifecycle of A2A tasks, including: - - Task submission and execution - - Task status queries - - Task cancellation - - Push notification configuration - - SSE streaming for real-time updates - """ - - def __init__( - self, - redis_cache: RedisCacheService, - agent_runner=None, - streaming_service=None, - push_notification_service=None, - db=None, - ): - """ - Initialize the A2A Task Manager. - - Args: - redis_cache: Redis cache service for task storage - agent_runner: Agent runner service for task execution - streaming_service: Streaming service for SSE - push_notification_service: Service for push notifications - db: Database session - """ - self.redis_cache = redis_cache - self.agent_runner = agent_runner - self.streaming_service = streaming_service - self.push_notification_service = push_notification_service - self.db = db - self.lock = asyncio.Lock() - self.subscriber_lock = asyncio.Lock() - self.task_sse_subscribers = {} - - async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: - """ - Handle request to get task information. - - Args: - request: A2A Get Task request - - Returns: - Response with task details - """ - try: - task_id = request.params.id - history_length = request.params.historyLength - - # Get task data from cache - task_data = await self.redis_cache.get(f"task:{task_id}") - - if not task_data: - logger.warning(f"Task not found: {task_id}") - return GetTaskResponse(id=request.id, error=TaskNotFoundError()) - - # Create a Task instance from cache data - task = Task.model_validate(task_data) - - # If historyLength parameter is present, handle the history - if history_length is not None and task.history: - if history_length == 0: - task.history = [] - elif len(task.history) > history_length: - task.history = task.history[-history_length:] - - return GetTaskResponse(id=request.id, result=task) - except Exception as e: - logger.error(f"Error processing on_get_task: {str(e)}") - return GetTaskResponse(id=request.id, error=InternalError(message=str(e))) - - async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: - """ - Handle request to cancel a running task. - - Args: - request: The JSON-RPC request to cancel a task - - Returns: - Response with updated task data or error - """ - logger.info(f"Cancelling task {request.params.id}") - task_id_params = request.params - - try: - task_data = await self.redis_cache.get(f"task:{task_id_params.id}") - if not task_data: - logger.warning(f"Task {task_id_params.id} not found for cancellation") - return CancelTaskResponse(id=request.id, error=TaskNotFoundError()) - - # Check if task can be cancelled - current_state = task_data.get("status", {}).get("state") - if current_state not in [TaskState.SUBMITTED, TaskState.WORKING]: - logger.warning( - f"Task {task_id_params.id} in state {current_state} cannot be cancelled" - ) - return CancelTaskResponse(id=request.id, error=TaskNotCancelableError()) - - # Update task status to cancelled - task_data["status"] = { - "state": TaskState.CANCELED, - "timestamp": datetime.now().isoformat(), - "message": { - "role": "agent", - "parts": [{"type": "text", "text": "Task cancelled by user"}], - }, - } - - # Save updated task data - await self.redis_cache.set(f"task:{task_id_params.id}", task_data) - - # Send push notification if configured - await self._send_push_notification_for_task( - task_id_params.id, "canceled", system_message="Task cancelled by user" - ) - - # Publish event to SSE subscribers - await self._publish_task_update( - task_id_params.id, - TaskStatusUpdateEvent( - id=task_id_params.id, - status=TaskStatus( - state=TaskState.CANCELED, - timestamp=datetime.now(), - message=Message( - role="agent", - parts=[TextPart(text="Task cancelled by user")], - ), - ), - final=True, - ), - ) - - return CancelTaskResponse(id=request.id, result=task_data) - - except Exception as e: - logger.error(f"Error cancelling task: {str(e)}", exc_info=True) - return CancelTaskResponse( - id=request.id, - error=InternalError(message=f"Error cancelling task: {str(e)}"), - ) - - async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """ - Handle request to send a new task. - - Args: - request: Send Task request - - Returns: - Response with the created task details - """ - try: - params = request.params - task_id = params.id - logger.info(f"Receiving task {task_id}") - - # Check if a task with this ID already exists - existing_task = await self.redis_cache.get(f"task:{task_id}") - if existing_task: - # If the task already exists and is in progress, return the current task - if existing_task.get("status", {}).get("state") in [ - TaskState.WORKING, - TaskState.COMPLETED, - ]: - return SendTaskResponse( - id=request.id, result=Task.model_validate(existing_task) - ) - - # If the task exists but failed or was canceled, we can reprocess it - logger.info(f"Reprocessing existing task {task_id}") - - # Check modality compatibility - server_output_modes = [] - if self.agent_runner: - # Try to get supported modes from the agent - try: - server_output_modes = await self.agent_runner.get_supported_modes() - except Exception as e: - logger.warning(f"Error getting supported modes: {str(e)}") - server_output_modes = ["text"] # Fallback to text - - if not are_modalities_compatible( - server_output_modes, params.acceptedOutputModes - ): - logger.warning( - f"Incompatible modes: server={server_output_modes}, client={params.acceptedOutputModes}" - ) - return SendTaskResponse( - id=request.id, error=ContentTypeNotSupportedError() - ) - - # Create task data - task_data = await self._create_task_data(params) - - # Store task in cache - await self.redis_cache.set(f"task:{task_id}", task_data) - - # Configure push notifications, if provided - if params.pushNotification: - await self.redis_cache.set( - f"task_notification:{task_id}", params.pushNotification.model_dump() - ) - - # Execute task SYNCHRONOUSLY instead of in background - # This is the key change for A2A compatibility - task_data = await self._execute_task(task_data, params) - - # Convert to Task object and return - task = Task.model_validate(task_data) - return SendTaskResponse(id=request.id, result=task) - - except Exception as e: - logger.error(f"Error processing on_send_task: {str(e)}") - return SendTaskResponse(id=request.id, error=InternalError(message=str(e))) - - async def on_send_task_subscribe( - self, request: SendTaskStreamingRequest - ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: - """ - Handle request to send a task and subscribe to streaming updates. - - Args: - request: The JSON-RPC request to send a task with streaming - - Returns: - Stream of events or error response - """ - logger.info(f"Sending task with streaming {request.params.id}") - task_send_params = request.params - - try: - # Check output mode compatibility - if not are_modalities_compatible( - ["text", "application/json"], # Default supported modes - task_send_params.acceptedOutputModes, - ): - return new_incompatible_types_error(request.id) - - # Create initial task data - task_data = await self._create_task_data(task_send_params) - - # Setup SSE consumer - sse_queue = await self._setup_sse_consumer(task_send_params.id) - - # Execute task asynchronously (fire and forget) - asyncio.create_task(self._execute_task(task_data, task_send_params)) - - # Return generator to dequeue events for SSE - return self._dequeue_events_for_sse( - request.id, task_send_params.id, sse_queue - ) - - except Exception as e: - logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True) - return SendTaskStreamingResponse( - id=request.id, - error=InternalError( - message=f"Error setting up streaming task: {str(e)}" - ), - ) - - async def on_set_task_push_notification( - self, request: SetTaskPushNotificationRequest - ) -> SetTaskPushNotificationResponse: - """ - Configure push notifications for a task. - - Args: - request: The JSON-RPC request to set push notification - - Returns: - Response with configuration or error - """ - logger.info(f"Setting push notification for task {request.params.id}") - task_notification_params = request.params - - try: - if not self.push_notification_service: - logger.warning("Push notifications not supported") - return SetTaskPushNotificationResponse( - id=request.id, error=PushNotificationNotSupportedError() - ) - - # Check if task exists - task_data = await self.redis_cache.get( - f"task:{task_notification_params.id}" - ) - if not task_data: - logger.warning( - f"Task {task_notification_params.id} not found for setting push notification" - ) - return SetTaskPushNotificationResponse( - id=request.id, error=TaskNotFoundError() - ) - - # Save push notification config - config = { - "url": task_notification_params.pushNotificationConfig.url, - "headers": {}, # Add auth headers if needed - } - - await self.redis_cache.set( - f"task:{task_notification_params.id}:push", config - ) - - return SetTaskPushNotificationResponse( - id=request.id, result=task_notification_params - ) - - except Exception as e: - logger.error(f"Error setting push notification: {str(e)}", exc_info=True) - return SetTaskPushNotificationResponse( - id=request.id, - error=InternalError( - message=f"Error setting push notification: {str(e)}" - ), - ) - - async def on_get_task_push_notification( - self, request: GetTaskPushNotificationRequest - ) -> GetTaskPushNotificationResponse: - """ - Get push notification configuration for a task. - - Args: - request: The JSON-RPC request to get push notification config - - Returns: - Response with configuration or error - """ - logger.info(f"Getting push notification for task {request.params.id}") - task_params = request.params - - try: - if not self.push_notification_service: - logger.warning("Push notifications not supported") - return GetTaskPushNotificationResponse( - id=request.id, error=PushNotificationNotSupportedError() - ) - - # Check if task exists - task_data = await self.redis_cache.get(f"task:{task_params.id}") - if not task_data: - logger.warning( - f"Task {task_params.id} not found for getting push notification" - ) - return GetTaskPushNotificationResponse( - id=request.id, error=TaskNotFoundError() - ) - - # Get push notification config - config = await self.redis_cache.get(f"task:{task_params.id}:push") - if not config: - logger.warning(f"No push notification config for task {task_params.id}") - return GetTaskPushNotificationResponse( - id=request.id, - error=InternalError( - message="No push notification configuration found" - ), - ) - - result = TaskPushNotificationConfig( - id=task_params.id, - pushNotificationConfig=PushNotificationConfig( - url=config.get("url"), token=None, authentication=None - ), - ) - - return GetTaskPushNotificationResponse(id=request.id, result=result) - - except Exception as e: - logger.error(f"Error getting push notification: {str(e)}", exc_info=True) - return GetTaskPushNotificationResponse( - id=request.id, - error=InternalError( - message=f"Error getting push notification: {str(e)}" - ), - ) - - async def on_resubscribe_to_task( - self, request: TaskResubscriptionRequest - ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: - """ - Resubscribe to a task's streaming events. - - Args: - request: The JSON-RPC request to resubscribe - - Returns: - Stream of events or error response - """ - logger.info(f"Resubscribing to task {request.params.id}") - task_params = request.params - - try: - # Check if task exists - task_data = await self.redis_cache.get(f"task:{task_params.id}") - if not task_data: - logger.warning(f"Task {task_params.id} not found for resubscription") - return JSONRPCResponse(id=request.id, error=TaskNotFoundError()) - - # Setup SSE consumer with resubscribe flag - try: - sse_queue = await self._setup_sse_consumer( - task_params.id, is_resubscribe=True - ) - except ValueError: - logger.warning( - f"Task {task_params.id} not available for resubscription" - ) - return JSONRPCResponse( - id=request.id, - error=InternalError( - message="Task not available for resubscription" - ), - ) - - # Send initial status update to the new subscriber - status = task_data.get("status", {}) - final = status.get("state") in [ - TaskState.COMPLETED, - TaskState.FAILED, - TaskState.CANCELED, - ] - - # Convert to TaskStatus object - task_status = TaskStatus( - state=status.get("state", TaskState.UNKNOWN), - timestamp=datetime.fromisoformat( - status.get("timestamp", datetime.now().isoformat()) - ), - message=status.get("message"), - ) - - # Publish to the specific queue - await sse_queue.put( - TaskStatusUpdateEvent( - id=task_params.id, status=task_status, final=final - ) - ) - - # Return generator to dequeue events for SSE - return self._dequeue_events_for_sse(request.id, task_params.id, sse_queue) - - except Exception as e: - logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True) - return JSONRPCResponse( - id=request.id, - error=InternalError(message=f"Error resubscribing to task: {str(e)}"), - ) - - async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]: - """ - Create initial task data structure. - - Args: - params: Task send parameters - - Returns: - Task data dictionary - """ - # Create task with initial status - task_data = { - "id": params.id, - "sessionId": params.sessionId - or str(uuid.uuid4()), # Preservar sessionId quando fornecido - "status": { - "state": TaskState.SUBMITTED, - "timestamp": datetime.now().isoformat(), - "message": None, - "error": None, - }, - "artifacts": [], - "history": [params.message.model_dump()], # Apenas mensagem do usuário - "metadata": params.metadata or {}, - } - - # Save task data - await self.redis_cache.set(f"task:{params.id}", task_data) - - return task_data - - async def _execute_task( - self, task: Dict[str, Any], params: TaskSendParams - ) -> Dict[str, Any]: - """ - Execute a task using the agent adapter. - - This function is responsible for executing the task by the agent, - updating its status as progress is made. - - Args: - task: Task data to be executed - params: Send task parameters - - Returns: - Updated task data with completed status and response - """ - task_id = task["id"] - agent_id = params.agentId - message_text = "" - - # Extract the text from the message - if params.message and params.message.parts: - for part in params.message.parts: - if part.type == "text": - message_text += part.text - - if not message_text: - await self._update_task_status_without_history( - task_id, TaskState.FAILED, "Message does not contain text", final=True - ) - # Return the updated task data - return await self.redis_cache.get(f"task:{task_id}") - - # Check if it is an ongoing execution - task_status = task.get("status", {}) - if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]: - logger.info(f"Task {task_id} is already in execution or completed") - # Return the current task data - return await self.redis_cache.get(f"task:{task_id}") - - try: - # Update to "working" state - NÃO adicionar ao histórico - await self._update_task_status_without_history( - task_id, TaskState.WORKING, "Processing request" - ) - - # Execute the agent - if self.agent_runner: - response = await self.agent_runner.run_agent( - agent_id=agent_id, - message=message_text, - session_id=params.sessionId, # Usar o sessionId da requisição - task_id=task_id, - ) - - # Process the agent's response - if response and isinstance(response, dict): - # Extract text from the response - response_text = response.get("content", "") - if not response_text and "message" in response: - message = response.get("message", {}) - parts = message.get("parts", []) - for part in parts: - if part.get("type") == "text": - response_text += part.get("text", "") - - # Build the final agent message - if response_text: - # Atualizar o histórico com a mensagem do usuário - await self._update_task_history(task_id, params.message) - - # Create an artifact for the response in Google A2A format - artifact = { - "parts": [{"type": "text", "text": response_text}], - "index": 0, - } - - # Add the artifact to the task - await self._add_task_artifact(task_id, artifact) - - # Update the task status to completed (sem adicionar ao histórico) - await self._update_task_status_without_history( - task_id, TaskState.COMPLETED, response_text, final=True - ) - else: - await self._update_task_status_without_history( - task_id, - TaskState.FAILED, - "The agent did not return a valid response", - final=True, - ) - else: - await self._update_task_status_without_history( - task_id, - TaskState.FAILED, - "Invalid agent response", - final=True, - ) - else: - await self._update_task_status_without_history( - task_id, - TaskState.FAILED, - "Agent adapter not configured", - final=True, - ) - except Exception as e: - logger.error(f"Error executing task {task_id}: {str(e)}") - await self._update_task_status_without_history( - task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True - ) - - # Return the updated task data - return await self.redis_cache.get(f"task:{task_id}") - - async def _update_task_history(self, task_id: str, message) -> None: - """ - Atualiza o histórico da tarefa incluindo apenas mensagens do usuário. - - Args: - task_id: ID da tarefa - message: Mensagem do usuário para adicionar ao histórico - """ - if not message: - return - - # Obter dados da tarefa atual - task_data = await self.redis_cache.get(f"task:{task_id}") - if not task_data: - logger.warning(f"Task {task_id} not found for history update") - return - - # Garantir que há um campo de histórico - if "history" not in task_data: - task_data["history"] = [] - - # Verificar se a mensagem já existe no histórico para evitar duplicação - user_message = ( - message.model_dump() if hasattr(message, "model_dump") else message - ) - message_exists = False - - for msg in task_data["history"]: - if self._compare_messages(msg, user_message): - message_exists = True - break - - # Adicionar mensagem se não existir - if not message_exists: - task_data["history"].append(user_message) - logger.info(f"Added new user message to history for task {task_id}") - - # Salvar tarefa atualizada - await self.redis_cache.set(f"task:{task_id}", task_data) - - # Método auxiliar para comparar mensagens e evitar duplicação no histórico - def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool: - """Compara duas mensagens para verificar se são essencialmente iguais.""" - if msg1.get("role") != msg2.get("role"): - return False - - parts1 = msg1.get("parts", []) - parts2 = msg2.get("parts", []) - - if len(parts1) != len(parts2): - return False - - for i in range(len(parts1)): - if parts1[i].get("type") != parts2[i].get("type"): - return False - if parts1[i].get("text") != parts2[i].get("text"): - return False - - return True - - async def _update_task_status_without_history( - self, task_id: str, state: TaskState, message_text: str, final: bool = False - ) -> None: - """ - Update the status of a task without changing the history. - - Args: - task_id: ID of the task to be updated - state: New task state - message_text: Text of the message associated with the status - final: Indicates if this is the final status of the task - """ - try: - # Get current task data - task_data = await self.redis_cache.get(f"task:{task_id}") - if not task_data: - logger.warning(f"Unable to update status: task {task_id} not found") - return - - # Create status object with the message - agent_message = Message( - role="agent", - parts=[TextPart(text=message_text)], - ) - - status = TaskStatus( - state=state, message=agent_message, timestamp=datetime.now() - ) - - # Update the status in the task - task_data["status"] = status.model_dump(exclude_none=True) - - # Store the updated task - await self.redis_cache.set(f"task:{task_id}", task_data) - - # Create status update event - status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final) - - # Publish status update - await self._publish_task_update(task_id, status_event) - - # Send push notification, if configured - if final or state in [ - TaskState.FAILED, - TaskState.COMPLETED, - TaskState.CANCELED, - ]: - await self._send_push_notification_for_task( - task_id=task_id, state=state, message_text=message_text - ) - except Exception as e: - logger.error(f"Error updating task status {task_id}: {str(e)}") - - async def _add_task_artifact(self, task_id: str, artifact) -> None: - """ - Add an artifact to a task and publish the update. - - Args: - task_id: Task ID - artifact: Artifact to add (dict no formato do Google) - """ - logger.info(f"Adding artifact to task {task_id}") - - # Update task data - task_data = await self.redis_cache.get(f"task:{task_id}") - if task_data: - if "artifacts" not in task_data: - task_data["artifacts"] = [] - - # Adicionar o artefato sem substituir os existentes - task_data["artifacts"].append(artifact) - await self.redis_cache.set(f"task:{task_id}", task_data) - - # Criar um artefato do tipo Artifact para o evento - artifact_obj = Artifact( - parts=[ - TextPart(text=part.get("text", "")) - for part in artifact.get("parts", []) - ], - index=artifact.get("index", 0), - ) - - # Create artifact update event - event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj) - - # Publish event - await self._publish_task_update(task_id, event) - - async def _publish_task_update( - self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent] - ) -> None: - """ - Publish a task update event to all subscribers. - - Args: - task_id: Task ID - event: Event to publish - """ - async with self.subscriber_lock: - if task_id not in self.task_sse_subscribers: - return - - subscribers = self.task_sse_subscribers[task_id] - for subscriber in subscribers: - try: - await subscriber.put(event) - except Exception as e: - logger.error(f"Error publishing event to subscriber: {str(e)}") - - async def _send_push_notification_for_task( - self, - task_id: str, - state: str, - message_text: str = None, - system_message: str = None, - ) -> None: - """ - Send push notification for a task if configured. - - Args: - task_id: Task ID - state: Task state - message_text: Optional message text - system_message: Optional system message - """ - if not self.push_notification_service: - return - - try: - # Get push notification config - config = await self.redis_cache.get(f"task:{task_id}:push") - if not config: - return - - # Create message if provided - message = None - if message_text: - message = { - "role": "agent", - "parts": [{"type": "text", "text": message_text}], - } - elif system_message: - # We use 'agent' instead of 'system' since Message only accepts 'user' or 'agent' - message = { - "role": "agent", - "parts": [{"type": "text", "text": system_message}], - } - - # Send notification - await self.push_notification_service.send_notification( - url=config["url"], - task_id=task_id, - state=state, - message=message, - headers=config.get("headers", {}), - ) - - except Exception as e: - logger.error( - f"Error sending push notification for task {task_id}: {str(e)}" - ) - - async def _setup_sse_consumer( - self, task_id: str, is_resubscribe: bool = False - ) -> asyncio.Queue: - """ - Set up an SSE consumer queue for a task. - - Args: - task_id: Task ID - is_resubscribe: Whether this is a resubscription - - Returns: - Queue for events - - Raises: - ValueError: If resubscribing to non-existent task - """ - async with self.subscriber_lock: - if task_id not in self.task_sse_subscribers: - if is_resubscribe: - raise ValueError("Task not found for resubscription") - self.task_sse_subscribers[task_id] = [] - - queue = asyncio.Queue() - self.task_sse_subscribers[task_id].append(queue) - return queue - - async def _dequeue_events_for_sse( - self, request_id: str, task_id: str, event_queue: asyncio.Queue - ) -> AsyncIterable[SendTaskStreamingResponse]: - """ - Dequeue and yield events for SSE streaming. - - Args: - request_id: Request ID - task_id: Task ID - event_queue: Queue for events - - Yields: - SSE events wrapped in SendTaskStreamingResponse - """ - try: - while True: - event = await event_queue.get() - - if isinstance(event, JSONRPCError): - yield SendTaskStreamingResponse(id=request_id, error=event) - break - - yield SendTaskStreamingResponse(id=request_id, result=event) - - # Check if this is the final event - is_final = False - if hasattr(event, "final") and event.final: - is_final = True - - if is_final: - break - finally: - # Clean up the subscription when done - async with self.subscriber_lock: - if task_id in self.task_sse_subscribers: - try: - self.task_sse_subscribers[task_id].remove(event_queue) - # Remove the task from the dict if no more subscribers - if not self.task_sse_subscribers[task_id]: - del self.task_sse_subscribers[task_id] - except ValueError: - pass # Queue might have been removed already diff --git a/src/services/push_notification_auth_service.py b/src/services/push_notification_auth_service.py deleted file mode 100644 index 228a8215..00000000 --- a/src/services/push_notification_auth_service.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -Push Notification Authentication Service. - -This service implements JWT authentication for A2A push notifications, -allowing agents to send authenticated notifications and clients to verify -the authenticity of received notifications. -""" - -from jwcrypto import jwk -import uuid -import time -import json -import hashlib -import httpx -import logging -import jwt -from jwt import PyJWK, PyJWKClient -from fastapi import Request -from starlette.responses import JSONResponse -from typing import Dict, Any - -logger = logging.getLogger(__name__) -AUTH_HEADER_PREFIX = "Bearer " - - -class PushNotificationAuth: - """ - Base class for push notification authentication. - - Contains common methods used by both the sender and the receiver - of push notifications. - """ - - def _calculate_request_body_sha256(self, data: Dict[str, Any]) -> str: - """ - Calculates the SHA256 hash of the request body. - - This logic needs to be the same for the agent that signs the payload - and for the client that verifies it. - - Args: - data: Request body data - - Returns: - SHA256 hash as a hexadecimal string - """ - body_str = json.dumps( - data, - ensure_ascii=False, - allow_nan=False, - indent=None, - separators=(",", ":"), - ) - return hashlib.sha256(body_str.encode()).hexdigest() - - -class PushNotificationSenderAuth(PushNotificationAuth): - """ - Authentication for the push notification sender. - - This class is used by the A2A server to authenticate notifications - sent to callback URLs of clients. - """ - - def __init__(self): - """ - Initialize the push notification sender authentication service. - """ - self.public_keys = [] - self.private_key_jwk = None - - @staticmethod - async def verify_push_notification_url(url: str) -> bool: - """ - Verifies if a push notification URL is valid and responds correctly. - - Sends a validation token and verifies if the response contains the same token. - - Args: - url: URL to be verified - - Returns: - True if the URL is verified successfully, False otherwise - """ - async with httpx.AsyncClient(timeout=10) as client: - try: - validation_token = str(uuid.uuid4()) - response = await client.get( - url, params={"validationToken": validation_token} - ) - response.raise_for_status() - is_verified = response.text == validation_token - - logger.info(f"Push notification URL verified: {url} => {is_verified}") - return is_verified - except Exception as e: - logger.warning(f"Error verifying push notification URL {url}: {e}") - - return False - - def generate_jwk(self): - """ - Generates a new JWK pair for signing. - - The key pair is used to sign push notifications. - The public key is available via the JWKS endpoint. - """ - key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig") - self.public_keys.append(key.export_public(as_dict=True)) - self.private_key_jwk = PyJWK.from_json(key.export_private()) - - def handle_jwks_endpoint(self, _request: Request) -> JSONResponse: - """ - Handles the JWKS endpoint to allow clients to obtain the public keys. - - Args: - _request: HTTP request - - Returns: - JSON response with the public keys - """ - return JSONResponse({"keys": self.public_keys}) - - def _generate_jwt(self, data: Dict[str, Any]) -> str: - """ - Generates a JWT token by signing the SHA256 hash of the payload and the timestamp. - - The payload is signed with the private key to ensure integrity. - The timestamp (iat) prevents replay attacks. - - Args: - data: Payload data - - Returns: - Signed JWT token - """ - iat = int(time.time()) - - return jwt.encode( - { - "iat": iat, - "request_body_sha256": self._calculate_request_body_sha256(data), - }, - key=self.private_key_jwk.key, - headers={"kid": self.private_key_jwk.key_id}, - algorithm="RS256", - ) - - async def send_push_notification(self, url: str, data: Dict[str, Any]) -> bool: - """ - Sends an authenticated push notification to the specified URL. - - Args: - url: URL to send the notification - data: Notification data - - Returns: - True if the notification was sent successfully, False otherwise - """ - if not self.private_key_jwk: - logger.error( - "Attempt to send push notification without generating JWK keys" - ) - return False - - try: - jwt_token = self._generate_jwt(data) - headers = {"Authorization": f"Bearer {jwt_token}"} - - async with httpx.AsyncClient(timeout=10) as client: - response = await client.post(url, json=data, headers=headers) - response.raise_for_status() - logger.info(f"Push notification sent to URL: {url}") - return True - except Exception as e: - logger.warning(f"Error sending push notification to URL {url}: {e}") - return False - - -class PushNotificationReceiverAuth(PushNotificationAuth): - """ - Authentication for the push notification receiver. - - This class is used by clients to verify the authenticity - of notifications received from the A2A server. - """ - - def __init__(self): - """ - Initialize the push notification receiver authentication service. - """ - self.jwks_client = None - - async def load_jwks(self, jwks_url: str): - """ - Loads the public JWKS keys from a URL. - - Args: - jwks_url: URL of the JWKS endpoint - """ - self.jwks_client = PyJWKClient(jwks_url) - - async def verify_push_notification(self, request: Request) -> bool: - """ - Verifies the authenticity of a push notification. - - Args: - request: HTTP request containing the notification - - Returns: - True if the notification is authentic, False otherwise - - Raises: - ValueError: If the token is invalid or expired - """ - if not self.jwks_client: - logger.error("Attempt to verify notification without loading JWKS keys") - return False - - # Verify authentication header - auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX): - logger.warning("Invalid authorization header") - return False - - try: - # Extract and verify token - token = auth_header[len(AUTH_HEADER_PREFIX) :] - signing_key = self.jwks_client.get_signing_key_from_jwt(token) - - # Decode token - decode_token = jwt.decode( - token, - signing_key.key, - options={"require": ["iat", "request_body_sha256"]}, - algorithms=["RS256"], - ) - - # Verify request body integrity - body_data = await request.json() - actual_body_sha256 = self._calculate_request_body_sha256(body_data) - if actual_body_sha256 != decode_token["request_body_sha256"]: - # The payload signature does not match the hash in the signed token - logger.warning("Request body hash does not match the token") - raise ValueError("Invalid request body") - - # Verify token age (maximum 5 minutes) - if time.time() - decode_token["iat"] > 60 * 5: - # Do not allow notifications older than 5 minutes - # This prevents replay attacks - logger.warning("Token expired") - raise ValueError("Token expired") - - return True - - except Exception as e: - logger.error(f"Error verifying push notification: {e}") - return False - - -# Global instance of the push notification sender authentication service -push_notification_auth = PushNotificationSenderAuth() - -# Generate JWK keys on initialization -push_notification_auth.generate_jwk() diff --git a/src/services/push_notification_service.py b/src/services/push_notification_service.py deleted file mode 100644 index 999ba567..00000000 --- a/src/services/push_notification_service.py +++ /dev/null @@ -1,99 +0,0 @@ -import aiohttp -import logging -from datetime import datetime -from typing import Dict, Any, Optional -import asyncio - -from src.services.push_notification_auth_service import push_notification_auth - -logger = logging.getLogger(__name__) - - -class PushNotificationService: - def __init__(self): - self.session = aiohttp.ClientSession() - - async def send_notification( - self, - url: str, - task_id: str, - state: str, - message: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - max_retries: int = 3, - retry_delay: float = 1.0, - use_jwt_auth: bool = True, - ) -> bool: - """ - Send a push notification to the specified URL. - Implements exponential backoff retry. - - Args: - url: URL to send the notification - task_id: Task ID - state: Task state - message: Optional message - headers: Optional HTTP headers - max_retries: Maximum number of retries - retry_delay: Initial delay between retries - use_jwt_auth: Whether to use JWT authentication - - Returns: - True if the notification was sent successfully, False otherwise - """ - payload = { - "taskId": task_id, - "state": state, - "timestamp": datetime.now().isoformat(), - "message": message, - } - - # First URL verification - if use_jwt_auth: - is_url_valid = await push_notification_auth.verify_push_notification_url( - url - ) - if not is_url_valid: - logger.warning(f"Invalid push notification URL: {url}") - return False - - for attempt in range(max_retries): - try: - if use_jwt_auth: - # Use JWT authentication - success = await push_notification_auth.send_push_notification( - url, payload - ) - if success: - return True - else: - # Legacy method without JWT authentication - async with self.session.post( - url, json=payload, headers=headers or {}, timeout=10 - ) as response: - if response.status in (200, 201, 202, 204): - logger.info(f"Push notification sent to URL: {url}") - return True - else: - logger.warning( - f"Failed to send push notification with status {response.status}. " - f"Attempt {attempt + 1}/{max_retries}" - ) - except Exception as e: - logger.error( - f"Error sending push notification: {str(e)}. " - f"Attempt {attempt + 1}/{max_retries}" - ) - - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay * (2**attempt)) - - return False - - async def close(self): - """Close the HTTP session""" - await self.session.close() - - -# Global instance of the push notification service -push_notification_service = PushNotificationService() diff --git a/src/services/redis_cache_service.py b/src/services/redis_cache_service.py deleted file mode 100644 index 72ed0ce2..00000000 --- a/src/services/redis_cache_service.py +++ /dev/null @@ -1,555 +0,0 @@ -""" -Cache Redis service for the A2A protocol. - -This service provides an interface for storing and retrieving data related to tasks, -push notification configurations, and other A2A-related data. -""" - -import json -import logging -from typing import Any, Dict, List, Optional -import asyncio -import redis.asyncio as aioredis -from src.config.redis import get_redis_config -import threading -import time - -logger = logging.getLogger(__name__) - - -class _InMemoryCacheFallback: - """ - Fallback in-memory cache implementation for when Redis is not available. - - This should only be used for development or testing environments. - """ - - _instance = None - _lock = threading.Lock() - - def __new__(cls): - """Singleton implementation.""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - """Initialize cache storage.""" - if not getattr(self, "_initialized", False): - with self._lock: - if not getattr(self, "_initialized", False): - self._data = {} - self._ttls = {} - self._hash_data = {} - self._list_data = {} - self._data_lock = threading.Lock() - self._initialized = True - logger.warning( - "Initializing in-memory cache fallback (not for production)" - ) - - async def set(self, key, value, ex=None): - """Set a key with optional expiration.""" - with self._data_lock: - self._data[key] = value - if ex is not None: - self._ttls[key] = time.time() + ex - elif key in self._ttls: - del self._ttls[key] - return True - - async def setex(self, key, ex, value): - """Set a key with expiration.""" - return await self.set(key, value, ex) - - async def get(self, key): - """Get a key value.""" - with self._data_lock: - # Check if expired - if key in self._ttls and time.time() > self._ttls[key]: - del self._data[key] - del self._ttls[key] - return None - return self._data.get(key) - - async def delete(self, key): - """Delete a key.""" - with self._data_lock: - if key in self._data: - del self._data[key] - if key in self._ttls: - del self._ttls[key] - return 1 - return 0 - - async def exists(self, key): - """Check if key exists.""" - with self._data_lock: - if key in self._ttls and time.time() > self._ttls[key]: - del self._data[key] - del self._ttls[key] - return 0 - return 1 if key in self._data else 0 - - async def hset(self, key, field, value): - """Set a hash field.""" - with self._data_lock: - if key not in self._hash_data: - self._hash_data[key] = {} - self._hash_data[key][field] = value - return 1 - - async def hget(self, key, field): - """Get a hash field.""" - with self._data_lock: - if key not in self._hash_data: - return None - return self._hash_data[key].get(field) - - async def hdel(self, key, field): - """Delete a hash field.""" - with self._data_lock: - if key in self._hash_data and field in self._hash_data[key]: - del self._hash_data[key][field] - return 1 - return 0 - - async def hgetall(self, key): - """Get all hash fields.""" - with self._data_lock: - if key not in self._hash_data: - return {} - return dict(self._hash_data[key]) - - async def rpush(self, key, value): - """Add to a list.""" - with self._data_lock: - if key not in self._list_data: - self._list_data[key] = [] - self._list_data[key].append(value) - return len(self._list_data[key]) - - async def lrange(self, key, start, end): - """Get range from list.""" - with self._data_lock: - if key not in self._list_data: - return [] - - # Handle negative indices - if end < 0: - end = len(self._list_data[key]) + end + 1 - - return self._list_data[key][start:end] - - async def expire(self, key, seconds): - """Set expiration on key.""" - with self._data_lock: - if key in self._data: - self._ttls[key] = time.time() + seconds - return 1 - return 0 - - async def flushdb(self): - """Clear all data.""" - with self._data_lock: - self._data.clear() - self._ttls.clear() - self._hash_data.clear() - self._list_data.clear() - return True - - async def keys(self, pattern="*"): - """Get keys matching pattern.""" - with self._data_lock: - # Clean expired keys - now = time.time() - expired_keys = [k for k, exp in self._ttls.items() if now > exp] - for k in expired_keys: - if k in self._data: - del self._data[k] - del self._ttls[k] - - # Simple pattern matching - result = [] - if pattern == "*": - result = list(self._data.keys()) - elif pattern.endswith("*"): - prefix = pattern[:-1] - result = [k for k in self._data.keys() if k.startswith(prefix)] - elif pattern.startswith("*"): - suffix = pattern[1:] - result = [k for k in self._data.keys() if k.endswith(suffix)] - else: - if pattern in self._data: - result = [pattern] - - return result - - async def ping(self): - """Test connection.""" - return True - - -class RedisCacheService: - """ - Cache service using Redis for storing A2A-related data. - - This implementation uses a real Redis connection for distributed caching - and data persistence across multiple instances. - - If Redis is not available, falls back to an in-memory implementation. - """ - - def __init__(self, redis_url: Optional[str] = None): - """ - Initialize the Redis cache service. - - Args: - redis_url: Redis server URL (optional, defaults to config value) - """ - if redis_url: - self._redis_url = redis_url - else: - # Construir URL a partir dos componentes de configuração - config = get_redis_config() - protocol = "rediss" if config.get("ssl", False) else "redis" - auth = f":{config['password']}@" if config.get("password") else "" - self._redis_url = ( - f"{protocol}://{auth}{config['host']}:{config['port']}/{config['db']}" - ) - - self._redis = None - self._in_memory_mode = False - self._connecting = False - self._connection_lock = asyncio.Lock() - logger.info(f"Initializing RedisCacheService with URL: {self._redis_url}") - - async def _get_redis(self): - """ - Get a Redis connection, creating it if necessary. - Falls back to in-memory implementation if Redis is not available. - - Returns: - Redis connection or in-memory fallback - """ - if self._redis is not None: - return self._redis - - async with self._connection_lock: - if self._redis is None and not self._connecting: - try: - self._connecting = True - logger.info(f"Connecting to Redis at {self._redis_url}") - self._redis = aioredis.from_url( - self._redis_url, encoding="utf-8", decode_responses=True - ) - # Teste de conexão - await self._redis.ping() - logger.info("Redis connection successful") - except Exception as e: - logger.error(f"Error connecting to Redis: {str(e)}") - logger.warning( - "Falling back to in-memory cache (not suitable for production)" - ) - self._redis = _InMemoryCacheFallback() - self._in_memory_mode = True - finally: - self._connecting = False - - return self._redis - - async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: - """ - Store a value in the cache. - - Args: - key: Key for the value - value: Value to store - ttl: Time to live in seconds (optional) - """ - try: - redis = await self._get_redis() - - # Convert dict/list to JSON string - if isinstance(value, (dict, list)): - value = json.dumps(value) - - if ttl: - await redis.setex(key, ttl, value) - else: - await redis.set(key, value) - - logger.debug(f"Set cache key: {key}") - except Exception as e: - logger.error(f"Error setting Redis key {key}: {str(e)}") - - async def get(self, key: str) -> Optional[Any]: - """ - Retrieve a value from the cache. - - Args: - key: Key for the value to retrieve - - Returns: - The stored value or None if not found - """ - try: - redis = await self._get_redis() - value = await redis.get(key) - - if value is None: - return None - - try: - # Try to parse as JSON - return json.loads(value) - except json.JSONDecodeError: - # Return as-is if not JSON - return value - - except Exception as e: - logger.error(f"Error getting Redis key {key}: {str(e)}") - return None - - async def delete(self, key: str) -> bool: - """ - Remove a value from the cache. - - Args: - key: Key for the value to remove - - Returns: - True if the value was removed, False if it didn't exist - """ - try: - redis = await self._get_redis() - result = await redis.delete(key) - return result > 0 - except Exception as e: - logger.error(f"Error deleting Redis key {key}: {str(e)}") - return False - - async def exists(self, key: str) -> bool: - """ - Check if a key exists in the cache. - - Args: - key: Key to check - - Returns: - True if the key exists, False otherwise - """ - try: - redis = await self._get_redis() - return await redis.exists(key) > 0 - except Exception as e: - logger.error(f"Error checking Redis key {key}: {str(e)}") - return False - - async def set_hash(self, key: str, field: str, value: Any) -> None: - """ - Store a value in a hash. - - Args: - key: Hash key - field: Hash field - value: Value to store - """ - try: - redis = await self._get_redis() - - # Convert dict/list to JSON string - if isinstance(value, (dict, list)): - value = json.dumps(value) - - await redis.hset(key, field, value) - logger.debug(f"Set hash field: {key}:{field}") - except Exception as e: - logger.error(f"Error setting Redis hash {key}:{field}: {str(e)}") - - async def get_hash(self, key: str, field: str) -> Optional[Any]: - """ - Retrieve a value from a hash. - - Args: - key: Hash key - field: Hash field - - Returns: - The stored value or None if not found - """ - try: - redis = await self._get_redis() - value = await redis.hget(key, field) - - if value is None: - return None - - try: - # Try to parse as JSON - return json.loads(value) - except json.JSONDecodeError: - # Return as-is if not JSON - return value - - except Exception as e: - logger.error(f"Error getting Redis hash {key}:{field}: {str(e)}") - return None - - async def delete_hash(self, key: str, field: str) -> bool: - """ - Remove a value from a hash. - - Args: - key: Hash key - field: Hash field - - Returns: - True if the value was removed, False if it didn't exist - """ - try: - redis = await self._get_redis() - result = await redis.hdel(key, field) - return result > 0 - except Exception as e: - logger.error(f"Error deleting Redis hash {key}:{field}: {str(e)}") - return False - - async def get_all_hash(self, key: str) -> Dict[str, Any]: - """ - Retrieve all values from a hash. - - Args: - key: Hash key - - Returns: - Dictionary with all hash values - """ - try: - redis = await self._get_redis() - result_dict = await redis.hgetall(key) - - if not result_dict: - return {} - - # Try to parse each value as JSON - parsed_dict = {} - for field, value in result_dict.items(): - try: - parsed_dict[field] = json.loads(value) - except json.JSONDecodeError: - parsed_dict[field] = value - - return parsed_dict - - except Exception as e: - logger.error(f"Error getting all Redis hash fields for {key}: {str(e)}") - return {} - - async def push_list(self, key: str, value: Any) -> int: - """ - Add a value to the end of a list. - - Args: - key: List key - value: Value to add - - Returns: - Size of the list after the addition - """ - try: - redis = await self._get_redis() - - # Convert dict/list to JSON string - if isinstance(value, (dict, list)): - value = json.dumps(value) - - return await redis.rpush(key, value) - except Exception as e: - logger.error(f"Error pushing to Redis list {key}: {str(e)}") - return 0 - - async def get_list(self, key: str, start: int = 0, end: int = -1) -> List[Any]: - """ - Retrieve values from a list. - - Args: - key: List key - start: Initial index (inclusive) - end: Final index (inclusive), -1 for all - - Returns: - List with the retrieved values - """ - try: - redis = await self._get_redis() - values = await redis.lrange(key, start, end) - - if not values: - return [] - - # Try to parse each value as JSON - result = [] - for value in values: - try: - result.append(json.loads(value)) - except json.JSONDecodeError: - result.append(value) - - return result - - except Exception as e: - logger.error(f"Error getting Redis list {key}: {str(e)}") - return [] - - async def expire(self, key: str, ttl: int) -> bool: - """ - Set a time-to-live for a key. - - Args: - key: Key - ttl: Time-to-live in seconds - - Returns: - True if the key exists and the TTL was set, False otherwise - """ - try: - redis = await self._get_redis() - return await redis.expire(key, ttl) - except Exception as e: - logger.error(f"Error setting expire for Redis key {key}: {str(e)}") - return False - - async def clear(self) -> None: - """ - Clear the entire cache. - - Warning: This is a destructive operation and will remove all data. - Only use in development/test environments. - """ - try: - redis = await self._get_redis() - await redis.flushdb() - logger.warning("Redis database flushed - all data cleared") - except Exception as e: - logger.error(f"Error clearing Redis database: {str(e)}") - - async def keys(self, pattern: str = "*") -> List[str]: - """ - Retrieve keys that match a pattern. - - Args: - pattern: Glob pattern to filter keys - - Returns: - List of keys that match the pattern - """ - try: - redis = await self._get_redis() - return await redis.keys(pattern) - except Exception as e: - logger.error(f"Error getting Redis keys with pattern {pattern}: {str(e)}") - return [] diff --git a/src/services/streaming_service.py b/src/services/streaming_service.py deleted file mode 100644 index 4cee4455..00000000 --- a/src/services/streaming_service.py +++ /dev/null @@ -1,141 +0,0 @@ -import uuid -import json -from typing import AsyncGenerator, Dict, Any -from fastapi import HTTPException -from datetime import datetime -from src.schemas.streaming import ( - JSONRPCRequest, - TaskStatusUpdateEvent, -) -from src.services.agent_runner import run_agent -from src.services.service_providers import ( - session_service, - artifacts_service, - memory_service, -) -from sqlalchemy.orm import Session - - -class StreamingService: - def __init__(self): - self.active_connections: Dict[str, Any] = {} - - async def send_task_streaming( - self, - agent_id: str, - api_key: str, - message: str, - contact_id: str = None, - session_id: str = None, - db: Session = None, - ) -> AsyncGenerator[str, None]: - """ - Starts the SSE event streaming for a task. - - Args: - agent_id: Agent ID - api_key: API key for authentication - message: Initial message - contact_id: Contact ID (optional) - session_id: Session ID (optional) - db: Database session - - Yields: - Formatted SSE events - """ - # Basic API key validation - if not api_key: - raise HTTPException(status_code=401, detail="API key is required") - - # Generate unique IDs - task_id = contact_id or str(uuid.uuid4()) - request_id = str(uuid.uuid4()) - - # Build JSON-RPC payload - payload = JSONRPCRequest( - id=request_id, - params={ - "id": task_id, - "sessionId": session_id, - "message": { - "role": "user", - "parts": [{"type": "text", "text": message}], - }, - }, - ) - - # Register connection - self.active_connections[task_id] = { - "agent_id": agent_id, - "api_key": api_key, - "session_id": session_id, - } - - try: - # Send start event - yield self._format_sse_event( - "status", - TaskStatusUpdateEvent( - state="working", - timestamp=datetime.now().isoformat(), - message=payload.params["message"], - ).model_dump_json(), - ) - - # Execute the agent - result = await run_agent( - str(agent_id), - contact_id or task_id, - message, - session_service, - artifacts_service, - memory_service, - db, - session_id, - ) - - # Send the agent's response as a separate event - yield self._format_sse_event( - "message", - json.dumps( - { - "role": "agent", - "content": result, - "timestamp": datetime.now().isoformat(), - } - ), - ) - - # Completion event - yield self._format_sse_event( - "status", - TaskStatusUpdateEvent( - state="completed", - timestamp=datetime.now().isoformat(), - ).model_dump_json(), - ) - - except Exception as e: - # Error event - yield self._format_sse_event( - "status", - TaskStatusUpdateEvent( - state="failed", - timestamp=datetime.now().isoformat(), - error={"message": str(e)}, - ).model_dump_json(), - ) - raise - - finally: - # Clean connection - self.active_connections.pop(task_id, None) - - def _format_sse_event(self, event_type: str, data: str) -> str: - """Format an SSE event.""" - return f"event: {event_type}\ndata: {data}\n\n" - - async def close_connection(self, task_id: str): - """Close a streaming connection.""" - if task_id in self.active_connections: - self.active_connections.pop(task_id) diff --git a/src/utils/a2a_utils.py b/src/utils/a2a_utils.py index e2a93e4b..1c6594ac 100644 --- a/src/utils/a2a_utils.py +++ b/src/utils/a2a_utils.py @@ -1,110 +1,28 @@ -""" -A2A protocol utilities. - -This module contains utility functions for the A2A protocol implementation. -""" - -import logging -from typing import List, Optional, Any, Dict -from src.schemas.a2a import ( +from src.schemas.a2a_types import ( ContentTypeNotSupportedError, - UnsupportedOperationError, JSONRPCResponse, + UnsupportedOperationError, ) -logger = logging.getLogger(__name__) - def are_modalities_compatible( - server_output_modes: Optional[List[str]], client_output_modes: Optional[List[str]] -) -> bool: - """ - Check if server and client modalities are compatible. - - Modalities are compatible if they are both non-empty + server_output_modes: list[str], client_output_modes: list[str] +): + """Modalities are compatible if they are both non-empty and there is at least one common element. - - Args: - server_output_modes: List of output modes supported by the server - client_output_modes: List of output modes requested by the client - - Returns: - True if compatible, False otherwise """ - # If client doesn't specify modes, assume all are accepted if client_output_modes is None or len(client_output_modes) == 0: return True - # If server doesn't specify modes, assume all are supported if server_output_modes is None or len(server_output_modes) == 0: return True - # Check if there's at least one common mode - return any(mode in server_output_modes for mode in client_output_modes) + return any(x in server_output_modes for x in client_output_modes) -def new_incompatible_types_error(request_id: str) -> JSONRPCResponse: - """ - Create a JSON-RPC response for incompatible content types error. - - Args: - request_id: The ID of the request that caused the error - - Returns: - JSON-RPC response with ContentTypeNotSupportedError - """ +def new_incompatible_types_error(request_id): return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError()) -def new_not_implemented_error(request_id: str) -> JSONRPCResponse: - """ - Create a JSON-RPC response for unimplemented operation error. - - Args: - request_id: The ID of the request that caused the error - - Returns: - JSON-RPC response with UnsupportedOperationError - """ +def new_not_implemented_error(request_id): return JSONRPCResponse(id=request_id, error=UnsupportedOperationError()) - - -def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str: - """ - Create a unique task ID for an agent and session. - - Args: - agent_id: The ID of the agent - session_id: The ID of the session - timestamp: Optional timestamp to include in the ID - - Returns: - A unique task ID - """ - import uuid - import time - - timestamp = timestamp or str(int(time.time())) - unique = uuid.uuid4().hex[:8] - - return f"{agent_id}_{session_id}_{timestamp}_{unique}" - - -def format_error_response(error: Exception, request_id: str = None) -> Dict[str, Any]: - """ - Format an exception as a JSON-RPC error response. - - Args: - error: The exception to format - request_id: The ID of the request that caused the error - - Returns: - JSON-RPC error response as dictionary - """ - from src.schemas.a2a import InternalError, JSONRPCResponse - - error_response = JSONRPCResponse( - id=request_id, error=InternalError(message=str(error)) - ) - - return error_response.model_dump(exclude_none=True)