refactor(a2a): restructure A2A routes and services for improved task management and API key verification
parent a0f984ae21
commit ec9dc07d71
@@ -6,29 +6,18 @@ This module implements the standard A2A routes according to the specification.

import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
import json
from fastapi import APIRouter, Depends, Header, Request, HTTPException
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from src.models.models import Agent
from src.config.database import get_db
from src.services import agent_service
from src.services import (
    RedisCacheService,
    AgentRunnerAdapter,
    StreamingServiceAdapter,
    create_agent_card_from_agent,
from src.services.a2a_task_manager import (
    A2ATaskManager,
    A2AService,
)
from src.services.a2a_task_manager_service import A2ATaskManager
from src.services.a2a_server_service import A2AServer
from src.services.agent_runner import run_agent
from src.services.service_providers import (
    session_service,
    artifacts_service,
    memory_service,
)
from src.services.push_notification_service import push_notification_service
from src.services.push_notification_auth_service import push_notification_auth
from src.services.streaming_service import StreamingService

logger = logging.getLogger(__name__)
@@ -43,85 +32,27 @@ router = APIRouter(
    },
)

streaming_service = StreamingService()
redis_cache_service = RedisCacheService()
streaming_adapter = StreamingServiceAdapter(streaming_service)

_task_manager_cache = {}
_agent_runner_cache = {}


def get_a2a_service(db: Session = Depends(get_db)):
    task_manager = A2ATaskManager(db)
    return A2AService(db, task_manager)


def get_agent_runner_adapter(db=None, reuse=True, agent_id=None):
    """
    Get or create an agent runner adapter.
async def verify_api_key(db: Session, x_api_key: str) -> bool:
    """Verifies the API key."""
    if not x_api_key:
        raise HTTPException(status_code=401, detail="API key not provided")

    Args:
        db: Database session
        reuse: Whether to reuse an existing instance
        agent_id: Agent ID to use as cache key

    Returns:
        Agent runner adapter instance
    """
    cache_key = str(agent_id) if agent_id else "default"

    if reuse and cache_key in _agent_runner_cache:
        adapter = _agent_runner_cache[cache_key]

        if db is not None:
            adapter.db = db
        return adapter

    adapter = AgentRunnerAdapter(
        agent_runner_func=run_agent,
        session_service=session_service,
        artifacts_service=artifacts_service,
        memory_service=memory_service,
        db=db,
    agent = (
        db.query(Agent)
        .filter(Agent.config.has_key("api_key"))
        .filter(Agent.config["api_key"].astext == x_api_key)
        .first()
    )

    if reuse:
        _agent_runner_cache[cache_key] = adapter

    return adapter


def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"):
    cache_key = str(agent_id)

    if operation_type == "query":
        if cache_key in _task_manager_cache:
            task_manager = _task_manager_cache[cache_key]
            task_manager.db = db
            return task_manager

        return A2ATaskManager(
            redis_cache=redis_cache_service,
            agent_runner=None,
            streaming_service=streaming_adapter,
            push_notification_service=push_notification_service,
            db=db,
        )

    if reuse and cache_key in _task_manager_cache:
        task_manager = _task_manager_cache[cache_key]
        task_manager.db = db
        return task_manager

    # Create new
    agent_runner_adapter = get_agent_runner_adapter(
        db=db, reuse=reuse, agent_id=agent_id
    )
    task_manager = A2ATaskManager(
        redis_cache=redis_cache_service,
        agent_runner=agent_runner_adapter,
        streaming_service=streaming_adapter,
        push_notification_service=push_notification_service,
        db=db,
    )
    _task_manager_cache[cache_key] = task_manager
    return task_manager
    if not agent:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True

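A minimal sketch of how the new verify_api_key helper is meant to be wired into an endpoint; the example route below is illustrative only and not part of this commit:

from fastapi import APIRouter, Depends, Header
from sqlalchemy.orm import Session

example_router = APIRouter()


@example_router.get("/example/{agent_id}")
async def example_endpoint(
    agent_id: str,
    x_api_key: str = Header(None, alias="x-api-key"),
    db: Session = Depends(get_db),
):
    # verify_api_key is async, so it must be awaited; it raises
    # HTTPException(401) for a missing or unmatched key.
    await verify_api_key(db, x_api_key)
    return {"agent_id": agent_id, "authorized": True}
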
@router.post("/{agent_id}")
@@ -130,156 +61,42 @@ async def process_a2a_request(
    request: Request,
    x_api_key: str = Header(None, alias="x-api-key"),
    db: Session = Depends(get_db),
    a2a_service: A2AService = Depends(get_a2a_service),
):
    """
    Main endpoint for processing JSON-RPC requests of the A2A protocol.
    """Processes an A2A request."""
    # Verify the API key
    if not await verify_api_key(db, x_api_key):
        raise HTTPException(status_code=401, detail="Invalid API key")

    This endpoint processes all JSON-RPC methods of the A2A protocol, including:
    - tasks/send: Sending tasks
    - tasks/sendSubscribe: Sending tasks with streaming
    - tasks/get: Querying task status
    - tasks/cancel: Cancelling tasks
    - tasks/pushNotification/set: Setting push notifications
    - tasks/pushNotification/get: Querying push notification configurations
    - tasks/resubscribe: Resubscribing to receive task updates

    Args:
        agent_id: Agent ID
        request: HTTP request with JSON-RPC payload
        x_api_key: API key for authentication
        db: Database session

    Returns:
        JSON-RPC response or streaming (SSE) depending on the method
    """
    # Process the request
    try:
        try:
            body = await request.json()
            method = body.get("method", "unknown")
            request_id = body.get("id")  # Extract request ID to ensure it's preserved
        request_body = await request.json()
        result = await a2a_service.process_request(agent_id, request_body)

            # Extract sessionId from params when the method is tasks/send
            session_id = None
            if method == "tasks/send" and body.get("params"):
                session_id = body.get("params", {}).get("sessionId")
                logger.info(f"Extracted sessionId from request: {session_id}")
        # If the response is a streaming response, return as EventSourceResponse
        if hasattr(result, "__aiter__"):

            is_query_request = method in [
                "tasks/get",
                "tasks/cancel",
                "tasks/pushNotification/get",
                "tasks/resubscribe",
            ]
            async def event_generator():
                async for item in result:
                    if hasattr(item, "model_dump_json"):
                        yield {"data": item.model_dump_json(exclude_none=True)}
                    else:
                        yield {"data": json.dumps(item)}

            reuse_components = is_query_request

        except Exception as e:
            logger.error(f"Error reading request body: {e}")
            return JSONResponse(
                status_code=400,
                content={
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32700,
                        "message": f"Parse error: {str(e)}",
                        "data": None,
                    },
                },
            )

        # Verify if the agent exists
        agent = agent_service.get_agent(db, agent_id)
        if agent is None:
            logger.warning(f"Agent not found: {agent_id}")
            return JSONResponse(
                status_code=404,
                content={
                    "jsonrpc": "2.0",
                    "id": request_id,  # Use the extracted request ID
                    "error": {"code": 404, "message": "Agent not found", "data": None},
                },
            )

        # Verify API key
        agent_config = agent.config

        if x_api_key and agent_config.get("api_key") != x_api_key:
            logger.warning(f"Invalid API Key for agent {agent_id}")
            return JSONResponse(
                status_code=401,
                content={
                    "jsonrpc": "2.0",
                    "id": request_id,  # Use the extracted request ID
                    "error": {"code": 401, "message": "Invalid API key", "data": None},
                },
            )

        a2a_task_manager = get_task_manager(
            agent_id,
            db=db,
            reuse=reuse_components,
            operation_type="query" if is_query_request else "execution",
        )
        a2a_server = A2AServer(task_manager=a2a_task_manager)

        agent_card = create_agent_card_from_agent(agent, db)
        a2a_server.agent_card = agent_card

        # Verify JSON-RPC format
        if not body.get("jsonrpc") or body.get("jsonrpc") != "2.0":
            logger.error(f"Invalid JSON-RPC format: {body.get('jsonrpc')}")
            return JSONResponse(
                status_code=400,
                content={
                    "jsonrpc": "2.0",
                    "id": request_id,  # Use the extracted request ID
                    "error": {
                        "code": -32600,
                        "message": "Invalid Request: jsonrpc must be '2.0'",
                        "data": None,
                    },
                },
            )

        if not body.get("method"):
            logger.error("Method not specified in request")
            return JSONResponse(
                status_code=400,
                content={
                    "jsonrpc": "2.0",
                    "id": request_id,  # Use the extracted request ID
                    "error": {
                        "code": -32600,
                        "message": "Invalid Request: method is required",
                        "data": None,
                    },
                },
            )

        # Process the request normally
        return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
            return EventSourceResponse(event_generator())

        # Otherwise, return as JSONResponse
        if hasattr(result, "model_dump"):
            return JSONResponse(result.model_dump(exclude_none=True))
        return JSONResponse(result)
    except Exception as e:
        logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
        # Try to extract request ID from the body, if available
        request_id = None
        try:
            body = await request.json()
            request_id = body.get("id")
        except Exception:
            pass

        logger.error(f"Error processing A2A request: {e}")
        return JSONResponse(
            status_code=500,
            content={
                "jsonrpc": "2.0",
                "id": request_id,  # Use the extracted request ID or None
                "error": {
                    "code": -32603,
                    "message": "Internal server error",
                    "data": {"detail": str(e)},
                },
                "id": None,
                "error": {"code": -32603, "message": "Internal server error"},
            },
        )

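For reference, a minimal tasks/send payload matching the methods handled above; all IDs and the message text are made-up example values:

example_request = {
    "jsonrpc": "2.0",
    "id": "req-1",
    "method": "tasks/send",
    "params": {
        "id": "task-123",
        "sessionId": "session-456",
        "message": {
            "role": "user",
            "parts": [{"type": "text", "text": "Hello, agent!"}],
        },
    },
}
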
@@ -289,81 +106,17 @@ async def get_agent_card(
    agent_id: uuid.UUID,
    request: Request,
    db: Session = Depends(get_db),
    a2a_service: A2AService = Depends(get_a2a_service),
):
    """
    Endpoint to get the Agent Card in the .well-known format of the A2A protocol.

    This endpoint returns the agent information in the standard A2A format,
    including capabilities, authentication information, and skills.

    Args:
        agent_id: Agent ID
        request: HTTP request
        db: Database session

    Returns:
        Agent Card in JSON format
    """
    """Gets the agent card for the specified agent."""
    try:
        agent = agent_service.get_agent(db, agent_id)
        if agent is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
            )

        agent_card = create_agent_card_from_agent(agent, db)

        a2a_task_manager = get_task_manager(agent_id, db=db, reuse=True)
        a2a_server = A2AServer(task_manager=a2a_task_manager)

        # Configure the A2A server with the agent card
        a2a_server.agent_card = agent_card

        # Use the A2A server to deliver the agent card, ensuring protocol compatibility
        return await a2a_server.get_agent_card(request, db=db)

        agent_card = a2a_service.get_agent_card(agent_id)
        if hasattr(agent_card, "model_dump"):
            return JSONResponse(agent_card.model_dump(exclude_none=True))
        return JSONResponse(agent_card)
    except Exception as e:
        logger.error(f"Error generating agent card: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error generating agent card",
        )


@router.get("/{agent_id}/.well-known/jwks.json")
async def get_jwks(
    agent_id: uuid.UUID,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Endpoint to get the public JWKS keys for verifying the authenticity
    of push notifications.

    Clients can use these keys to verify the authenticity of received notifications.

    Args:
        agent_id: Agent ID
        request: HTTP request
        db: Database session

    Returns:
        JSON with the public keys in JWKS format
    """
    try:
        # Verify if the agent exists
        agent = agent_service.get_agent(db, agent_id)
        if agent is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
            )

        # Return the public keys
        return push_notification_auth.handle_jwks_endpoint(request)

    except Exception as e:
        logger.error(f"Error obtaining JWKS: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error obtaining JWKS",
        logger.error(f"Error getting agent card: {e}")
        return JSONResponse(
            status_code=404,
            content={"error": f"Agent not found: {str(e)}"},
        )

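A client-side sketch of fetching the card served by this endpoint; the well-known file name here follows the A2A convention and is an assumption, not shown in this diff:

import httpx


def fetch_agent_card(base_url: str, agent_id: str) -> dict:
    # Assumed path: the agent card is conventionally exposed under
    # .well-known/agent.json beneath the agent's A2A base URL.
    response = httpx.get(
        f"{base_url}/api/v1/a2a/{agent_id}/.well-known/agent.json"
    )
    response.raise_for_status()
    return response.json()
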
@@ -1,140 +0,0 @@
"""
A2A (Agent-to-Agent) schema package.

This package contains Pydantic schema definitions for the A2A protocol.
"""

from src.schemas.a2a.types import (
    TaskState,
    TextPart,
    FileContent,
    FilePart,
    DataPart,
    Part,
    Message,
    TaskStatus,
    Artifact,
    Task,
    TaskStatusUpdateEvent,
    TaskArtifactUpdateEvent,
    AuthenticationInfo,
    PushNotificationConfig,
    TaskIdParams,
    TaskQueryParams,
    TaskSendParams,
    TaskPushNotificationConfig,
    JSONRPCMessage,
    JSONRPCRequest,
    JSONRPCResponse,
    JSONRPCError,
    SendTaskRequest,
    SendTaskResponse,
    SendTaskStreamingRequest,
    SendTaskStreamingResponse,
    GetTaskRequest,
    GetTaskResponse,
    CancelTaskRequest,
    CancelTaskResponse,
    SetTaskPushNotificationRequest,
    SetTaskPushNotificationResponse,
    GetTaskPushNotificationRequest,
    GetTaskPushNotificationResponse,
    TaskResubscriptionRequest,
    A2ARequest,
    AgentProvider,
    AgentCapabilities,
    AgentAuthentication,
    AgentSkill,
    AgentCard,
)

from src.schemas.a2a.exceptions import (
    JSONParseError,
    InvalidRequestError,
    MethodNotFoundError,
    InvalidParamsError,
    InternalError,
    TaskNotFoundError,
    TaskNotCancelableError,
    PushNotificationNotSupportedError,
    UnsupportedOperationError,
    ContentTypeNotSupportedError,
    A2AClientError,
    A2AClientHTTPError,
    A2AClientJSONError,
    MissingAPIKeyError,
)

from src.schemas.a2a.validators import (
    validate_base64,
    validate_file_content,
    validate_message_parts,
    text_to_parts,
    parts_to_text,
)

__all__ = [
    # From types
    "TaskState",
    "TextPart",
    "FileContent",
    "FilePart",
    "DataPart",
    "Part",
    "Message",
    "TaskStatus",
    "Artifact",
    "Task",
    "TaskStatusUpdateEvent",
    "TaskArtifactUpdateEvent",
    "AuthenticationInfo",
    "PushNotificationConfig",
    "TaskIdParams",
    "TaskQueryParams",
    "TaskSendParams",
    "TaskPushNotificationConfig",
    "JSONRPCMessage",
    "JSONRPCRequest",
    "JSONRPCResponse",
    "JSONRPCError",
    "SendTaskRequest",
    "SendTaskResponse",
    "SendTaskStreamingRequest",
    "SendTaskStreamingResponse",
    "GetTaskRequest",
    "GetTaskResponse",
    "CancelTaskRequest",
    "CancelTaskResponse",
    "SetTaskPushNotificationRequest",
    "SetTaskPushNotificationResponse",
    "GetTaskPushNotificationRequest",
    "GetTaskPushNotificationResponse",
    "TaskResubscriptionRequest",
    "A2ARequest",
    "AgentProvider",
    "AgentCapabilities",
    "AgentAuthentication",
    "AgentSkill",
    "AgentCard",
    # From exceptions
    "JSONParseError",
    "InvalidRequestError",
    "MethodNotFoundError",
    "InvalidParamsError",
    "InternalError",
    "TaskNotFoundError",
    "TaskNotCancelableError",
    "PushNotificationNotSupportedError",
    "UnsupportedOperationError",
    "ContentTypeNotSupportedError",
    "A2AClientError",
    "A2AClientHTTPError",
    "A2AClientJSONError",
    "MissingAPIKeyError",
    # From validators
    "validate_base64",
    "validate_file_content",
    "validate_message_parts",
    "text_to_parts",
    "parts_to_text",
]
@@ -1,147 +0,0 @@
"""
A2A (Agent-to-Agent) protocol exception definitions.

This module contains error types and exceptions for the A2A protocol.
"""

from src.schemas.a2a.types import JSONRPCError


class JSONParseError(JSONRPCError):
    """
    Error raised when JSON parsing fails.
    """

    code: int = -32700
    message: str = "Invalid JSON payload"
    data: object | None = None


class InvalidRequestError(JSONRPCError):
    """
    Error raised when request validation fails.
    """

    code: int = -32600
    message: str = "Request payload validation error"
    data: object | None = None


class MethodNotFoundError(JSONRPCError):
    """
    Error raised when the requested method is not found.
    """

    code: int = -32601
    message: str = "Method not found"
    data: None = None


class InvalidParamsError(JSONRPCError):
    """
    Error raised when the parameters are invalid.
    """

    code: int = -32602
    message: str = "Invalid parameters"
    data: object | None = None


class InternalError(JSONRPCError):
    """
    Error raised when an internal error occurs.
    """

    code: int = -32603
    message: str = "Internal error"
    data: object | None = None


class TaskNotFoundError(JSONRPCError):
    """
    Error raised when the requested task is not found.
    """

    code: int = -32001
    message: str = "Task not found"
    data: None = None


class TaskNotCancelableError(JSONRPCError):
    """
    Error raised when a task cannot be canceled.
    """

    code: int = -32002
    message: str = "Task cannot be canceled"
    data: None = None


class PushNotificationNotSupportedError(JSONRPCError):
    """
    Error raised when push notifications are not supported.
    """

    code: int = -32003
    message: str = "Push Notification is not supported"
    data: None = None


class UnsupportedOperationError(JSONRPCError):
    """
    Error raised when an operation is not supported.
    """

    code: int = -32004
    message: str = "This operation is not supported"
    data: None = None


class ContentTypeNotSupportedError(JSONRPCError):
    """
    Error raised when content types are incompatible.
    """

    code: int = -32005
    message: str = "Incompatible content types"
    data: None = None


# Client exceptions


class A2AClientError(Exception):
    """
    Base exception for A2A client errors.
    """

    pass


class A2AClientHTTPError(A2AClientError):
    """
    Exception for HTTP errors in A2A client.
    """

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(f"HTTP Error {status_code}: {message}")


class A2AClientJSONError(A2AClientError):
    """
    Exception for JSON errors in A2A client.
    """

    def __init__(self, message: str):
        self.message = message
        super().__init__(f"JSON Error: {message}")


class MissingAPIKeyError(Exception):
    """
    Exception for missing API key.
    """

    pass
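A sketch of how these error types were used, embedding one in a JSON-RPC response; it assumes the equivalent classes now live in src.schemas.a2a_types after this refactor:

error = TaskNotFoundError()  # defaults: code=-32001, message="Task not found"
response = JSONRPCResponse(id="req-1", error=error)
print(response.model_dump(exclude_none=True))
# {'jsonrpc': '2.0', 'id': 'req-1', 'error': {'code': -32001, 'message': 'Task not found'}}
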
@@ -1,124 +0,0 @@
"""
A2A (Agent-to-Agent) protocol validators.

This module contains validators for the A2A protocol data.
"""

from typing import List
import base64
import re
from pydantic import ValidationError
import logging
from src.schemas.a2a.types import Part, TextPart, FilePart, DataPart, FileContent

logger = logging.getLogger(__name__)


def validate_base64(value: str) -> bool:
    """
    Validates if a string is valid base64.

    Args:
        value: String to validate

    Returns:
        True if valid base64, False otherwise
    """
    try:
        if not value:
            return False

        # Check if the string has base64 characters only
        pattern = r"^[A-Za-z0-9+/]+={0,2}$"
        if not re.match(pattern, value):
            return False

        # Try to decode
        base64.b64decode(value)
        return True
    except Exception as e:
        logger.warning(f"Invalid base64 string: {e}")
        return False


def validate_file_content(file_content: FileContent) -> bool:
    """
    Validates file content.

    Args:
        file_content: FileContent to validate

    Returns:
        True if valid, False otherwise
    """
    try:
        if file_content.bytes is not None:
            return validate_base64(file_content.bytes)
        elif file_content.uri is not None:
            # Basic URL validation
            pattern = r"^https?://.+"
            return bool(re.match(pattern, file_content.uri))
        return False
    except Exception as e:
        logger.warning(f"Invalid file content: {e}")
        return False


def validate_message_parts(parts: List[Part]) -> bool:
    """
    Validates all parts in a message.

    Args:
        parts: List of parts to validate

    Returns:
        True if all parts are valid, False otherwise
    """
    try:
        for part in parts:
            if isinstance(part, TextPart):
                if not part.text or not isinstance(part.text, str):
                    return False
            elif isinstance(part, FilePart):
                if not validate_file_content(part.file):
                    return False
            elif isinstance(part, DataPart):
                if not part.data or not isinstance(part.data, dict):
                    return False
            else:
                return False
        return True
    except (ValidationError, Exception) as e:
        logger.warning(f"Invalid message parts: {e}")
        return False


def text_to_parts(text: str) -> List[Part]:
    """
    Converts a plain text to a list of message parts.

    Args:
        text: Plain text to convert

    Returns:
        List containing a single TextPart
    """
    return [TextPart(text=text)]


def parts_to_text(parts: List[Part]) -> str:
    """
    Extracts text from a list of message parts.

    Args:
        parts: List of parts to extract text from

    Returns:
        Concatenated text from all text parts
    """
    text = ""
    for part in parts:
        if isinstance(part, TextPart):
            text += part.text
        # Could add handling for other part types here
    return text
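A round-trip sketch of the helper functions defined above:

parts = text_to_parts("hello world")          # [TextPart(text="hello world")]
assert validate_message_parts(parts) is True
assert parts_to_text(parts) == "hello world"
assert validate_base64("aGVsbG8=") is True    # "hello", base64-encoded
assert validate_base64("not base64!") is False
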
@@ -1,31 +1,20 @@
"""
A2A (Agent-to-Agent) protocol type definitions.
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, TypeVar
from uuid import uuid4
from typing_extensions import Self

This module contains Pydantic schema definitions for the A2A protocol.
"""

from typing import Union, Any, List, Optional, Annotated, Literal
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    field_serializer,
    model_validator,
    ConfigDict,
)
from datetime import datetime
from uuid import uuid4
from enum import Enum
from typing_extensions import Self


class TaskState(str, Enum):
    """
    Enum for the state of a task in the A2A protocol.

    States follow the A2A protocol specification.
    """

    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"
@@ -36,22 +25,12 @@ class TaskState(str, Enum):


class TextPart(BaseModel):
    """
    Represents a text part in a message.
    """

    type: Literal["text"] = "text"
    text: str
    metadata: dict[str, Any] | None = None


class FileContent(BaseModel):
    """
    Represents file content in a file part.

    Either bytes or uri must be provided, but not both.
    """

    name: str | None = None
    mimeType: str | None = None
    bytes: str | None = None
@@ -59,9 +38,6 @@ class FileContent(BaseModel):

    @model_validator(mode="after")
    def check_content(self) -> Self:
        """
        Validates that either bytes or uri is present, but not both.
        """
        if not (self.bytes or self.uri):
            raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
        if self.bytes and self.uri:
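The check_content validator makes bytes and uri mutually exclusive; a short sketch of accepted and rejected shapes:

from pydantic import ValidationError

FileContent(uri="https://example.com/file.txt")  # accepted: uri only
FileContent(bytes="aGVsbG8=")                    # accepted: bytes only
try:
    FileContent()  # neither field set -> rejected by check_content
except ValidationError as err:
    print(err)
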
@@ -72,65 +48,40 @@ class FileContent(BaseModel):


class FilePart(BaseModel):
    """
    Represents a file part in a message.
    """

    type: Literal["file"] = "file"
    file: FileContent
    metadata: dict[str, Any] | None = None


class DataPart(BaseModel):
    """
    Represents a data part in a message.
    """

    type: Literal["data"] = "data"
    data: dict[str, Any]
    metadata: dict[str, Any] | None = None


Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
Part = Annotated[TextPart | FilePart | DataPart, Field(discriminator="type")]


class Message(BaseModel):
    """
    Represents a message in the A2A protocol.

    A message consists of a role and one or more parts.
    """

    role: Literal["user", "agent"]
    parts: List[Part]
    parts: list[Part]
    metadata: dict[str, Any] | None = None


class TaskStatus(BaseModel):
    """
    Represents the status of a task.
    """

    state: TaskState
    message: Message | None = None
    timestamp: datetime = Field(default_factory=datetime.now)

    @field_serializer("timestamp")
    def serialize_dt(self, dt: datetime, _info):
        """
        Serializes datetime to ISO format.
        """
        return dt.isoformat()


class Artifact(BaseModel):
    """
    Represents an artifact produced by an agent.
    """

    name: str | None = None
    description: str | None = None
    parts: List[Part]
    parts: list[Part]
    metadata: dict[str, Any] | None = None
    index: int = 0
    append: bool | None = None
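Because Part discriminates on the type field, pydantic can pick the concrete part model when validating raw data; a small sketch:

from pydantic import TypeAdapter

part_adapter = TypeAdapter(Part)
part = part_adapter.validate_python({"type": "text", "text": "hi"})
assert isinstance(part, TextPart)
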
@@ -138,23 +89,15 @@ class Artifact(BaseModel):


class Task(BaseModel):
    """
    Represents a task in the A2A protocol.
    """

    id: str
    sessionId: str | None = None
    status: TaskStatus
    artifacts: List[Artifact] | None = None
    history: List[Message] | None = None
    artifacts: list[Artifact] | None = None
    history: list[Message] | None = None
    metadata: dict[str, Any] | None = None


class TaskStatusUpdateEvent(BaseModel):
    """
    Represents a task status update event.
    """

    id: str
    status: TaskStatus
    final: bool = False
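The field_serializer on TaskStatus means timestamps dump as ISO-8601 strings; for example:

status = TaskStatus(state=TaskState.SUBMITTED)
dumped = status.model_dump()
# dumped["timestamp"] is a string such as "2025-01-01T12:00:00"
assert isinstance(dumped["timestamp"], str)
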
@@ -162,295 +105,234 @@ class TaskStatusUpdateEvent(BaseModel):


class TaskArtifactUpdateEvent(BaseModel):
    """
    Represents a task artifact update event.
    """

    id: str
    artifact: Artifact
    metadata: dict[str, Any] | None = None


class AuthenticationInfo(BaseModel):
    """
    Represents authentication information for push notifications.
    """

    model_config = ConfigDict(extra="allow")

    schemes: List[str]
    schemes: list[str]
    credentials: str | None = None


class PushNotificationConfig(BaseModel):
    """
    Represents push notification configuration.
    """

    url: str
    token: str | None = None
    authentication: AuthenticationInfo | None = None


class TaskIdParams(BaseModel):
    """
    Represents parameters for identifying a task.
    """

    id: str
    metadata: dict[str, Any] | None = None


class TaskQueryParams(TaskIdParams):
    """
    Represents parameters for querying a task.
    """

    historyLength: int | None = None


class TaskSendParams(BaseModel):
    """
    Represents parameters for sending a task.
    """

    id: str
    sessionId: str = Field(default_factory=lambda: uuid4().hex)
    message: Message
    acceptedOutputModes: Optional[List[str]] = None
    acceptedOutputModes: list[str] | None = None
    pushNotification: PushNotificationConfig | None = None
    historyLength: int | None = None
    metadata: dict[str, Any] | None = None
    agentId: str = ""


class TaskPushNotificationConfig(BaseModel):
    """
    Represents push notification configuration for a task.
    """

    id: str
    pushNotificationConfig: PushNotificationConfig


# RPC Messages
## RPC Messages


class JSONRPCMessage(BaseModel):
    """
    Base class for JSON-RPC messages.
    """

    jsonrpc: Literal["2.0"] = "2.0"
    id: int | str | None = Field(default_factory=lambda: uuid4().hex)


class JSONRPCRequest(JSONRPCMessage):
    """
    Represents a JSON-RPC request.
    """

    method: str
    params: dict[str, Any] | None = None


class JSONRPCError(BaseModel):
    """
    Represents a JSON-RPC error.
    """

    code: int
    message: str
    data: Any | None = None


class JSONRPCResponse(JSONRPCMessage):
    """
    Represents a JSON-RPC response.
    """

    result: Any | None = None
    error: JSONRPCError | None = None


class SendTaskRequest(JSONRPCRequest):
    """
    Represents a request to send a task.
    """

    method: Literal["tasks/send"] = "tasks/send"
    params: TaskSendParams


class SendTaskResponse(JSONRPCResponse):
    """
    Represents a response to a send task request.
    """

    result: Task | None = None


class SendTaskStreamingRequest(JSONRPCRequest):
    """
    Represents a request to send a task with streaming.
    """

    method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
    params: TaskSendParams


class SendTaskStreamingResponse(JSONRPCResponse):
    """
    Represents a streaming response to a send task request.
    """

    result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None


class GetTaskRequest(JSONRPCRequest):
    """
    Represents a request to get task information.
    """

    method: Literal["tasks/get"] = "tasks/get"
    params: TaskQueryParams


class GetTaskResponse(JSONRPCResponse):
    """
    Represents a response to a get task request.
    """

    result: Task | None = None


class CancelTaskRequest(JSONRPCRequest):
    """
    Represents a request to cancel a task.
    """

    method: Literal["tasks/cancel"] = "tasks/cancel"
    params: TaskIdParams


class CancelTaskResponse(JSONRPCResponse):
    """
    Represents a response to a cancel task request.
    """

    result: Task | None = None


class SetTaskPushNotificationRequest(JSONRPCRequest):
    """
    Represents a request to set push notification for a task.
    """

    method: Literal["tasks/pushNotification/set"] = "tasks/pushNotification/set"
    params: TaskPushNotificationConfig


class SetTaskPushNotificationResponse(JSONRPCResponse):
    """
    Represents a response to a set push notification request.
    """

    result: TaskPushNotificationConfig | None = None


class GetTaskPushNotificationRequest(JSONRPCRequest):
    """
    Represents a request to get push notification configuration for a task.
    """

    method: Literal["tasks/pushNotification/get"] = "tasks/pushNotification/get"
    params: TaskIdParams


class GetTaskPushNotificationResponse(JSONRPCResponse):
    """
    Represents a response to a get push notification request.
    """

    result: TaskPushNotificationConfig | None = None


class TaskResubscriptionRequest(JSONRPCRequest):
    """
    Represents a request to resubscribe to a task.
    """

    method: Literal["tasks/resubscribe"] = "tasks/resubscribe"
    params: TaskIdParams


# TypeAdapter for discriminating A2A requests by method
A2ARequest = TypeAdapter(
    Annotated[
        Union[
            SendTaskRequest,
            GetTaskRequest,
            CancelTaskRequest,
            SetTaskPushNotificationRequest,
            GetTaskPushNotificationRequest,
            TaskResubscriptionRequest,
            SendTaskStreamingRequest,
        ],
        SendTaskRequest
        | GetTaskRequest
        | CancelTaskRequest
        | SetTaskPushNotificationRequest
        | GetTaskPushNotificationRequest
        | TaskResubscriptionRequest
        | SendTaskStreamingRequest,
        Field(discriminator="method"),
    ]
)

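A2ARequest dispatches on the JSON-RPC method field, returning the matching request model; a minimal sketch:

req = A2ARequest.validate_python(
    {
        "jsonrpc": "2.0",
        "id": "req-1",
        "method": "tasks/get",
        "params": {"id": "task-123"},
    }
)
assert isinstance(req, GetTaskRequest)
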
## Error types

# Agent Card schemas


class JSONParseError(JSONRPCError):
    code: int = -32700
    message: str = "Invalid JSON payload"
    data: Any | None = None


class InvalidRequestError(JSONRPCError):
    code: int = -32600
    message: str = "Request payload validation error"
    data: Any | None = None


class MethodNotFoundError(JSONRPCError):
    code: int = -32601
    message: str = "Method not found"
    data: None = None


class InvalidParamsError(JSONRPCError):
    code: int = -32602
    message: str = "Invalid parameters"
    data: Any | None = None


class InternalError(JSONRPCError):
    code: int = -32603
    message: str = "Internal error"
    data: Any | None = None


class TaskNotFoundError(JSONRPCError):
    code: int = -32001
    message: str = "Task not found"
    data: None = None


class TaskNotCancelableError(JSONRPCError):
    code: int = -32002
    message: str = "Task cannot be canceled"
    data: None = None


class PushNotificationNotSupportedError(JSONRPCError):
    code: int = -32003
    message: str = "Push Notification is not supported"
    data: None = None


class UnsupportedOperationError(JSONRPCError):
    code: int = -32004
    message: str = "This operation is not supported"
    data: None = None


class ContentTypeNotSupportedError(JSONRPCError):
    code: int = -32005
    message: str = "Incompatible content types"
    data: None = None


class AgentProvider(BaseModel):
    """
    Represents the provider of an agent.
    """

    organization: str
    url: str | None = None


class AgentCapabilities(BaseModel):
    """
    Represents the capabilities of an agent.
    """

    streaming: bool = False
    pushNotifications: bool = False
    stateTransitionHistory: bool = False


class AgentAuthentication(BaseModel):
    """
    Represents the authentication requirements for an agent.
    """

    schemes: List[str]
    schemes: list[str]
    credentials: str | None = None


class AgentSkill(BaseModel):
    """
    Represents a skill of an agent.
    """

    id: str
    name: str
    description: str | None = None
    tags: List[str] | None = None
    examples: List[str] | None = None
    inputModes: List[str] | None = None
    outputModes: List[str] | None = None
    tags: list[str] | None = None
    examples: list[str] | None = None
    inputModes: list[str] | None = None
    outputModes: list[str] | None = None


class AgentCard(BaseModel):
    """
    Represents an agent card in the A2A protocol.
    """

    name: str
    description: str | None = None
    url: str
@@ -459,6 +341,27 @@ class AgentCard(BaseModel):
    documentationUrl: str | None = None
    capabilities: AgentCapabilities
    authentication: AgentAuthentication | None = None
    defaultInputModes: List[str] = ["text"]
    defaultOutputModes: List[str] = ["text"]
    skills: List[AgentSkill]
    defaultInputModes: list[str] = ["text"]
    defaultOutputModes: list[str] = ["text"]
    skills: list[AgentSkill]


class A2AClientError(Exception):
    pass


class A2AClientHTTPError(A2AClientError):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(f"HTTP Error {status_code}: {message}")


class A2AClientJSONError(A2AClientError):
    def __init__(self, message: str):
        self.message = message
        super().__init__(f"JSON Error: {message}")


class MissingAPIKeyError(Exception):
    """Exception for missing API key."""
@@ -1,9 +1 @@
from .agent_runner import run_agent
from .redis_cache_service import RedisCacheService
from .a2a_task_manager_service import A2ATaskManager
from .a2a_server_service import A2AServer
from .a2a_integration_service import (
    AgentRunnerAdapter,
    StreamingServiceAdapter,
    create_agent_card_from_agent,
)
@@ -7,7 +7,7 @@ import json
import asyncio
import time

from src.schemas.a2a.types import (
from src.schemas.a2a_types import (
    GetTaskRequest,
    SendTaskRequest,
    Message,
@@ -1,507 +0,0 @@
"""
A2A Integration Service.

This service provides adapters to integrate existing services with the A2A protocol.
"""

import json
import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, AsyncIterable

from src.schemas.a2a import (
    AgentCard,
    AgentCapabilities,
    AgentProvider,
    Artifact,
    Message,
    TaskArtifactUpdateEvent,
    TaskStatus,
    TaskStatusUpdateEvent,
    TextPart,
)

logger = logging.getLogger(__name__)


class AgentRunnerAdapter:
    """
    Adapter for integrating the existing agent runner with the A2A protocol.
    """

    def __init__(
        self,
        agent_runner_func,
        session_service=None,
        artifacts_service=None,
        memory_service=None,
        db=None,
    ):
        """
        Initialize the adapter.

        Args:
            agent_runner_func: The agent runner function (e.g., run_agent)
            session_service: Session service for message history
            artifacts_service: Artifacts service for artifact history
            memory_service: Memory service for agent memory
            db: Database session
        """
        self.agent_runner_func = agent_runner_func
        self.session_service = session_service
        self.artifacts_service = artifacts_service
        self.memory_service = memory_service
        self.db = db

    async def get_supported_modes(self) -> List[str]:
        """
        Get the supported output modes for the agent.

        Returns:
            List of supported output modes
        """
        # Default modes, can be extended based on agent configuration
        return ["text", "application/json"]

    async def run_agent(
        self,
        agent_id: str,
        message: str,
        session_id: Optional[str] = None,
        task_id: Optional[str] = None,
        db=None,
    ) -> Dict[str, Any]:
        """
        Run the agent with the given message.

        Args:
            agent_id: ID of the agent to run
            message: User message to process
            session_id: Optional session ID for conversation context
            task_id: Optional task ID for tracking
            db: Database session

        Returns:
            Dictionary with the agent's response
        """

        try:
            # Use the existing agent runner function
            # Use the provided session_id, or generate a new one
            session_id = session_id or str(uuid.uuid4())
            task_id = task_id or str(uuid.uuid4())

            # Use the provided db or fall back to self.db
            db_session = db if db is not None else self.db

            response_text = await self.agent_runner_func(
                agent_id=agent_id,
                contact_id=task_id,
                message=message,
                session_service=self.session_service,
                artifacts_service=self.artifacts_service,
                memory_service=self.memory_service,
                db=db_session,
                session_id=session_id,
            )

            # Format the response to include both the A2A-compatible message
            # and the artifact format to match the Google A2A implementation
            # Note: the artifact format is simplified for compatibility with Google A2A
            message_obj = {
                "role": "agent",
                "parts": [{"type": "text", "text": response_text}],
            }

            # Artifact format compatible with Google A2A
            artifact_obj = {
                "parts": [{"type": "text", "text": response_text}],
                "index": 0,
            }

            return {
                "status": "success",
                "content": response_text,
                "session_id": session_id,
                "task_id": task_id,
                "timestamp": datetime.now().isoformat(),
                "message": message_obj,
                "artifact": artifact_obj,
            }

        except Exception as e:
            logger.error(f"[AGENT-RUNNER] Error running agent: {e}", exc_info=True)
            return {
                "status": "error",
                "error": str(e),
                "session_id": session_id,
                "task_id": task_id,
                "timestamp": datetime.now().isoformat(),
                "message": {
                    "role": "agent",
                    "parts": [{"type": "text", "text": f"Error: {str(e)}"}],
                },
            }

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running task.

        Args:
            task_id: ID of the task to cancel

        Returns:
            True if successfully canceled, False otherwise
        """
        # Currently, the agent runner doesn't support cancellation
        # This is a placeholder for future implementation
        logger.warning(f"Task cancellation not implemented for task {task_id}")
        return False

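A hypothetical call into the adapter; the agent ID and message are placeholder values, and the service providers are assumed to be the module-level instances imported elsewhere in this commit:

import asyncio


async def demo():
    adapter = AgentRunnerAdapter(
        agent_runner_func=run_agent,
        session_service=session_service,
        artifacts_service=artifacts_service,
        memory_service=memory_service,
    )
    # session_id and task_id are generated internally when omitted.
    result = await adapter.run_agent(agent_id="agent-1", message="ping")
    print(result["status"], result["content"])


asyncio.run(demo())
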

class StreamingServiceAdapter:
    """
    Adapter for integrating the existing streaming service with the A2A protocol.
    """

    def __init__(self, streaming_service):
        """
        Initialize the adapter.

        Args:
            streaming_service: The streaming service instance
        """
        self.streaming_service = streaming_service

    async def stream_agent_response(
        self,
        agent_id: str,
        message: str,
        api_key: str,
        session_id: Optional[str] = None,
        task_id: Optional[str] = None,
        db=None,
    ) -> AsyncIterable[str]:
        """
        Stream the agent's response as A2A events.

        Args:
            agent_id: ID of the agent
            message: User message to process
            api_key: API key for authentication
            session_id: Optional session ID for conversation context
            task_id: Optional task ID for tracking
            db: Database session

        Yields:
            A2A event objects as JSON strings for SSE (Server-Sent Events)
        """
        task_id = task_id or str(uuid.uuid4())
        logger.info(f"Starting streaming response for task {task_id}")

        # Set working status event
        working_status = TaskStatus(
            state="working",
            timestamp=datetime.now(),
            message=Message(
                role="agent", parts=[TextPart(text="Processing your request...")]
            ),
        )

        status_event = TaskStatusUpdateEvent(
            id=task_id, status=working_status, final=False
        )
        yield json.dumps(status_event.model_dump())

        content_buffer = ""
        final_sent = False
        has_error = False

        # Stream from the existing streaming service
        try:
            logger.info(f"Setting up streaming for agent {agent_id}, task {task_id}")
            # For streaming, we use task_id as the contact_id
            contact_id = task_id

            last_event_time = datetime.now()
            heartbeat_interval = 20

            async for event in self.streaming_service.send_task_streaming(
                agent_id=agent_id,
                api_key=api_key,
                message=message,
                contact_id=contact_id,
                session_id=session_id,
                db=db,
            ):
                last_event_time = datetime.now()

                # Process the streaming event format
                event_data = event.get("data", "{}")
                try:
                    logger.info(f"Processing event data: {event_data[:100]}...")
                    data = json.loads(event_data)

                    # Extract content
                    if "delta" in data and data["delta"].get("content"):
                        content = data["delta"]["content"]
                        content_buffer += content
                        logger.info(f"Received content chunk: {content[:50]}...")

                        # Create artifact update event
                        artifact = Artifact(
                            name="response",
                            parts=[TextPart(text=content)],
                            index=0,
                            append=True,
                            lastChunk=False,
                        )

                        artifact_event = TaskArtifactUpdateEvent(
                            id=task_id, artifact=artifact
                        )
                        yield json.dumps(artifact_event.model_dump())

                    # Check if final event
                    if data.get("done", False) and not final_sent:
                        logger.info(f"Received final event for task {task_id}")
                        # Create completed status event
                        completed_status = TaskStatus(
                            state="completed",
                            timestamp=datetime.now(),
                            message=Message(
                                role="agent",
                                parts=[
                                    TextPart(text=content_buffer or "Task completed")
                                ],
                            ),
                        )

                        # Final artifact with full content
                        final_artifact = Artifact(
                            name="response",
                            parts=[TextPart(text=content_buffer)],
                            index=0,
                            append=False,
                            lastChunk=True,
                        )

                        # Send the final artifact
                        final_artifact_event = TaskArtifactUpdateEvent(
                            id=task_id, artifact=final_artifact
                        )

                        yield json.dumps(final_artifact_event.model_dump())

                        # Send the completed status
                        final_status_event = TaskStatusUpdateEvent(
                            id=task_id,
                            status=completed_status,
                            final=True,
                        )

                        yield json.dumps(final_status_event.model_dump())

                        final_sent = True

                except json.JSONDecodeError as e:
                    logger.warning(
                        f"Received non-JSON event data: {e}. Data: {event_data[:100]}..."
                    )
                    # Handle non-JSON events - simply add to buffer as text
                    if isinstance(event_data, str):
                        content_buffer += event_data

                        # Create artifact update event
                        artifact = Artifact(
                            name="response",
                            parts=[TextPart(text=event_data)],
                            index=0,
                            append=True,
                            lastChunk=False,
                        )

                        artifact_event = TaskArtifactUpdateEvent(
                            id=task_id, artifact=artifact
                        )

                        yield json.dumps(artifact_event.model_dump())
                    elif isinstance(event_data, dict):
                        # Try to extract text from the dictionary
                        text_value = str(event_data)
                        content_buffer += text_value

                        artifact = Artifact(
                            name="response",
                            parts=[TextPart(text=text_value)],
                            index=0,
                            append=True,
                            lastChunk=False,
                        )

                        artifact_event = TaskArtifactUpdateEvent(
                            id=task_id, artifact=artifact
                        )

                        yield json.dumps(artifact_event.model_dump())

                # Send heartbeat/keep-alive to keep the SSE connection open
                now = datetime.now()
                if (now - last_event_time).total_seconds() > heartbeat_interval:
                    logger.info(f"Sending heartbeat for task {task_id}")
                    # Sending keep-alive event as a "working" status event
                    working_heartbeat = TaskStatus(
                        state="working",
                        timestamp=now,
                        message=Message(
                            role="agent", parts=[TextPart(text="Still processing...")]
                        ),
                    )
                    heartbeat_event = TaskStatusUpdateEvent(
                        id=task_id, status=working_heartbeat, final=False
                    )
                    yield json.dumps(heartbeat_event.model_dump())
                    last_event_time = now

            # Ensure we send a final event if not already sent
            if not final_sent:
                logger.info(
                    f"Stream completed for task {task_id}, sending final status"
                )
                # Create completed status event
                completed_status = TaskStatus(
                    state="completed",
                    timestamp=datetime.now(),
                    message=Message(
                        role="agent",
                        parts=[TextPart(text=content_buffer or "Task completed")],
                    ),
                )

                # Send the completed status
                final_event = TaskStatusUpdateEvent(
                    id=task_id, status=completed_status, final=True
                )
                yield json.dumps(final_event.model_dump())

        except Exception as e:
            has_error = True
            logger.error(f"Error in streaming for task {task_id}: {e}", exc_info=True)

            # Create failed status event
            failed_status = TaskStatus(
                state="failed",
                timestamp=datetime.now(),
                message=Message(
                    role="agent",
                    parts=[
                        TextPart(
                            text=f"Error during streaming: {str(e)}. Partial response: {content_buffer[:200] if content_buffer else 'No content received'}"
                        )
                    ],
                ),
            )

            error_event = TaskStatusUpdateEvent(
                id=task_id, status=failed_status, final=True
            )
            yield json.dumps(error_event.model_dump())

        finally:
            # Ensure we send a final event to properly close the connection
            if not final_sent and not has_error:
                logger.info(f"Stream finalizing for task {task_id} via finally block")
                try:
                    # Create completed status event
                    completed_status = TaskStatus(
                        state="completed",
                        timestamp=datetime.now(),
                        message=Message(
                            role="agent",
                            parts=[
                                TextPart(
                                    text=content_buffer or "Task completed (forced end)"
                                )
                            ],
                        ),
                    )

                    # Send the completed status
                    final_event = TaskStatusUpdateEvent(
                        id=task_id, status=completed_status, final=True
                    )
                    yield json.dumps(final_event.model_dump())
                except Exception as final_error:
                    logger.error(
                        f"Error sending final event in finally block: {final_error}"
                    )

            logger.info(f"Streaming completed for task {task_id}")


def create_agent_card_from_agent(agent, db) -> AgentCard:
    """
    Create an A2A agent card from an agent model.

    Args:
        agent: The agent model from the database
        db: Database session

    Returns:
        A2A AgentCard object
    """
    import os
    from src.api.agent_routes import format_agent_tools
    import asyncio

    # Extract agent configuration
    agent_config = agent.config
    has_streaming = True  # Assuming streaming is always supported
    has_push = True  # Assuming push notifications are supported

    # Format tools as skills
    try:
        # We use a different approach to handle the asynchronous function
        mcp_servers = agent_config.get("mcp_servers", [])

        # We create a new thread to execute the asynchronous function
        import concurrent.futures

        def run_async(coro):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            result = loop.run_until_complete(coro)
            loop.close()
            return result

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(run_async, format_agent_tools(mcp_servers, db))
            skills = future.result()
    except Exception as e:
        logger.error(f"Error formatting agent tools: {e}")
        skills = []

    # Create agent card
    return AgentCard(
        name=agent.name,
        description=agent.description,
        url=f"{os.getenv('API_URL', '')}/api/v1/a2a/{agent.id}",
        provider=AgentProvider(
            organization=os.getenv("ORGANIZATION_NAME", ""),
            url=os.getenv("ORGANIZATION_URL", ""),
        ),
        version=os.getenv("API_VERSION", "1.0.0"),
        capabilities=AgentCapabilities(
            streaming=has_streaming,
            pushNotifications=has_push,
            stateTransitionHistory=True,
        ),
        authentication={
            "schemes": ["apiKey"],
            "credentials": "x-api-key",
        },
        defaultInputModes=["text", "application/json"],
        defaultOutputModes=["text", "application/json"],
        skills=skills,
    )
@@ -1,708 +0,0 @@
"""
A2A server and task manager for the A2A protocol.

This module implements a JSON-RPC compatible server for the A2A protocol
that manages agent tasks, streaming events, and push notifications.
"""

import json
import logging
from datetime import datetime
from typing import (
    Any,
    Dict,
    List,
    Optional,
    AsyncGenerator,
    Union,
    AsyncIterable,
)
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse, Response
from sqlalchemy.orm import Session

from src.schemas.a2a.types import A2ARequest
from src.services.a2a_integration_service import (
    AgentRunnerAdapter,
    StreamingServiceAdapter,
)
from src.services.session_service import get_session_events
from src.services.redis_cache_service import RedisCacheService
from src.schemas.a2a.types import (
    SendTaskRequest,
    SendTaskStreamingRequest,
    GetTaskRequest,
    CancelTaskRequest,
    SetTaskPushNotificationRequest,
    GetTaskPushNotificationRequest,
    TaskResubscriptionRequest,
)

logger = logging.getLogger(__name__)


class A2ATaskManager:
    """
    Task manager for the A2A protocol.

    This class manages the lifecycle of tasks, including:
    - Task execution
    - Streaming of events
    - Push notifications
    - Status querying
    - Cancellation
    """

    def __init__(
        self,
        redis_cache: RedisCacheService,
        agent_runner: AgentRunnerAdapter,
        streaming_service: StreamingServiceAdapter,
        push_notification_service: Any = None,
    ):
        """
        Initialize the task manager.

        Args:
            redis_cache: Cache service for storing task data
            agent_runner: Adapter for agent execution
            streaming_service: Adapter for event streaming
            push_notification_service: Service for sending push notifications
        """
        self.cache = redis_cache
        self.agent_runner = agent_runner
        self.streaming_service = streaming_service
        self.push_notification_service = push_notification_service
        self._running_tasks = {}

    async def on_send_task(
        self,
        task_id: str,
        agent_id: str,
        message: Dict[str, Any],
        session_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        input_mode: str = "text",
        output_modes: List[str] = ["text"],
        db: Optional[Session] = None,
    ) -> Dict[str, Any]:
        """
        Process a request to send a task.

        Args:
            task_id: Task ID
            agent_id: Agent ID
            message: User message
            session_id: Session ID (optional)
            metadata: Additional metadata (optional)
            input_mode: Input mode (text, JSON, etc.)
            output_modes: Supported output modes
            db: Database session

        Returns:
            Response with task result
        """
        if not session_id:
            session_id = f"{task_id}_{agent_id}"

        if not metadata:
            metadata = {}

        # Update status to "submitted"
        task_data = {
            "id": task_id,
            "sessionId": session_id,
            "status": {
                "state": "submitted",
                "timestamp": datetime.now().isoformat(),
                "message": None,
                "error": None,
            },
            "artifacts": [],
            "history": [],
            "metadata": metadata,
        }

        # Store initial task data
        await self.cache.set(f"task:{task_id}", task_data)

        # Check for push notification configurations
        push_config = await self.cache.get(f"task:{task_id}:push")
        if push_config and self.push_notification_service:
            # Send initial notification
            await self.push_notification_service.send_notification(
                url=push_config["url"],
                task_id=task_id,
                state="submitted",
                headers=push_config.get("headers", {}),
            )

        try:
            # Update status to "running"
            task_data["status"].update(
                {"state": "running", "timestamp": datetime.now().isoformat()}
            )
            await self.cache.set(f"task:{task_id}", task_data)

            # Notify "running" state
            if push_config and self.push_notification_service:
                await self.push_notification_service.send_notification(
                    url=push_config["url"],
                    task_id=task_id,
                    state="running",
                    headers=push_config.get("headers", {}),
                )

            # Extract user message
            user_message = None
            try:
                user_message = message["parts"][0]["text"]
||||
except (KeyError, IndexError):
|
||||
user_message = ""
|
||||
|
||||
# Execute the agent
|
||||
response = await self.agent_runner.run_agent(
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
message=user_message,
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Check if the response is a dictionary (error) or a string (success)
|
||||
if isinstance(response, dict) and response.get("status") == "error":
|
||||
# Error response
|
||||
final_response = f"Error: {response.get('error', 'Unknown error')}"
|
||||
|
||||
# Update status to "failed"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"error": {
|
||||
"code": "AGENT_EXECUTION_ERROR",
|
||||
"message": response.get("error", "Unknown error"),
|
||||
},
|
||||
"message": {
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Notify "failed" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="failed",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
else:
|
||||
# Success response
|
||||
final_response = (
|
||||
response.get("content") if isinstance(response, dict) else response
|
||||
)
|
||||
|
||||
# Update status to "completed"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add artifacts
|
||||
if final_response:
|
||||
task_data["artifacts"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": final_response,
|
||||
"metadata": {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"content_type": "text/plain",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add history of messages
|
||||
history_length = metadata.get("historyLength", 50)
|
||||
try:
|
||||
history_messages = get_session_events(
|
||||
self.agent_runner.session_service, session_id
|
||||
)
|
||||
history_messages = history_messages[-history_length:]
|
||||
|
||||
formatted_history = []
|
||||
for event in history_messages:
|
||||
if event.content and event.content.parts:
|
||||
role = (
|
||||
"agent"
|
||||
if event.content.role == "model"
|
||||
else event.content.role
|
||||
)
|
||||
formatted_history.append(
|
||||
{
|
||||
"role": role,
|
||||
"parts": [
|
||||
{"type": "text", "text": part.text}
|
||||
for part in event.content.parts
|
||||
if part.text
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
task_data["history"] = formatted_history
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing history: {str(e)}")
|
||||
|
||||
# Notify "completed" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="completed",
|
||||
message={
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task {task_id}: {str(e)}")
|
||||
|
||||
# Update status to "failed"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)},
|
||||
}
|
||||
)
|
||||
|
||||
# Notify "failed" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="failed",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": str(e)}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
# Store final result
|
||||
await self.cache.set(f"task:{task_id}", task_data)
|
||||
return task_data
|
||||
|
||||
async def on_send_task_subscribe(
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
message: Dict[str, Any],
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
input_mode: str = "text",
|
||||
output_modes: List[str] = ["text"],
|
||||
db: Optional[Session] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Process a request to send a task with streaming.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session ID (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
input_mode: Input mode (text, JSON, etc.)
|
||||
output_modes: Supported output modes
|
||||
db: Database session
|
||||
|
||||
Yields:
|
||||
Streaming events in SSE (Server-Sent Events) format
|
||||
"""
|
||||
if not session_id:
|
||||
session_id = f"{task_id}_{agent_id}"
|
||||
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
# Extract user message
|
||||
user_message = ""
|
||||
try:
|
||||
user_message = message["parts"][0]["text"]
|
||||
except (KeyError, IndexError):
|
||||
pass
|
||||
|
||||
# Generate streaming events
|
||||
async for event in self.streaming_service.stream_response(
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
message=user_message,
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
):
|
||||
yield event
|
||||
|
||||
async def on_get_task(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Query the status of a task by ID.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Current task status
|
||||
|
||||
Raises:
|
||||
Exception: If the task is not found
|
||||
"""
|
||||
task_data = await self.cache.get(f"task:{task_id}")
|
||||
if not task_data:
|
||||
raise Exception(f"Task {task_id} not found")
|
||||
return task_data
|
||||
|
||||
async def on_cancel_task(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to be cancelled
|
||||
|
||||
Returns:
|
||||
Task status after cancellation
|
||||
|
||||
Raises:
|
||||
Exception: If the task is not found or cannot be cancelled
|
||||
"""
|
||||
task_data = await self.cache.get(f"task:{task_id}")
|
||||
if not task_data:
|
||||
raise Exception(f"Task {task_id} not found")
|
||||
|
||||
# Check if the task is in a state that can be cancelled
|
||||
current_state = task_data["status"]["state"]
|
||||
if current_state not in ["submitted", "running"]:
|
||||
raise Exception(f"Cannot cancel task in {current_state} state")
|
||||
|
||||
# Cancel the task in the runner if it is running
|
||||
running_task = self._running_tasks.get(task_id)
|
||||
if running_task:
|
||||
# Try to cancel the running task
|
||||
if hasattr(running_task, "cancel"):
|
||||
running_task.cancel()
|
||||
|
||||
# Update status to "cancelled"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "cancelled",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": "Task cancelled by user"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Update cache
|
||||
await self.cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Send push notification if configured
|
||||
push_config = await self.cache.get(f"task:{task_id}:push")
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="cancelled",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": "Task cancelled by user"}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
return task_data
|
||||
|
||||
async def on_set_task_push_notification(
|
||||
self, task_id: str, notification_config: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Configure push notifications for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
notification_config: Notification configuration (URL and headers)
|
||||
|
||||
Returns:
|
||||
Updated configuration
|
||||
"""
|
||||
# Validate configuration
|
||||
url = notification_config.get("url")
|
||||
if not url:
|
||||
raise ValueError("Push notification URL is required")
|
||||
|
||||
headers = notification_config.get("headers", {})
|
||||
|
||||
# Store configuration
|
||||
config = {"url": url, "headers": headers}
|
||||
await self.cache.set(f"task:{task_id}:push", config)
|
||||
|
||||
return config
|
||||
|
||||
async def on_get_task_push_notification(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the push notification configuration for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Push notification configuration
|
||||
|
||||
Raises:
|
||||
Exception: If there is no configuration for the task
|
||||
"""
|
||||
config = await self.cache.get(f"task:{task_id}:push")
|
||||
if not config:
|
||||
raise Exception(f"No push notification configuration for task {task_id}")
|
||||
return config
|
||||
|
||||
|
||||
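

# Illustration of the cached records the task manager above reads and writes.
# The task record is stored under "task:{task_id}" and has roughly this shape:
#
#     {
#         "id": "<task id>",
#         "sessionId": "<session id>",
#         "status": {
#             "state": "submitted" | "running" | "completed"
#                      | "failed" | "cancelled",
#             "timestamp": "<ISO-8601>",
#             "message": {...} | None,
#             "error": {...} | None,
#         },
#         "artifacts": [...],
#         "history": [...],
#         "metadata": {...},
#     }
#
# The push notification configuration lives under "task:{task_id}:push" as
# {"url": "<callback URL>", "headers": {...}}.

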
class A2AServer:
    """
    A2A server compatible with JSON-RPC 2.0.

    This class processes JSON-RPC requests and forwards them to
    the appropriate handlers in the A2ATaskManager.
    """

    def __init__(self, task_manager: A2ATaskManager, agent_card=None):
        """
        Initialize the A2A server.

        Args:
            task_manager: Task manager
            agent_card: Agent card information
        """
        self.task_manager = task_manager
        self.agent_card = agent_card

    async def process_request(
        self,
        request: Request,
        agent_id: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Union[Response, JSONResponse, StreamingResponse]:
        """
        Process a JSON-RPC request.

        Args:
            request: HTTP request
            agent_id: Optional agent ID to inject into the request
            db: Database session

        Returns:
            Appropriate response (JSON or Streaming)
        """
        try:
            # Try to parse the JSON payload
            try:
                logger.info("Starting JSON-RPC request processing")
                body = await request.json()
                logger.info(f"Received JSON data: {json.dumps(body)}")
                method = body.get("method", "unknown")

                # Validate the request using the A2A validator
                json_rpc_request = A2ARequest.validate_python(body)

                original_db = self.task_manager.db
                try:
                    # Set the db temporarily
                    if db is not None:
                        self.task_manager.db = db

                    # Process the request
                    if isinstance(json_rpc_request, SendTaskRequest):
                        json_rpc_request.params.agentId = agent_id
                        result = await self.task_manager.on_send_task(json_rpc_request)
                    elif isinstance(json_rpc_request, SendTaskStreamingRequest):
                        json_rpc_request.params.agentId = agent_id
                        result = await self.task_manager.on_send_task_subscribe(
                            json_rpc_request
                        )
                    elif isinstance(json_rpc_request, GetTaskRequest):
                        result = await self.task_manager.on_get_task(json_rpc_request)
                    elif isinstance(json_rpc_request, CancelTaskRequest):
                        result = await self.task_manager.on_cancel_task(
                            json_rpc_request
                        )
                    elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
                        result = await self.task_manager.on_set_task_push_notification(
                            json_rpc_request
                        )
                    elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
                        result = await self.task_manager.on_get_task_push_notification(
                            json_rpc_request
                        )
                    elif isinstance(json_rpc_request, TaskResubscriptionRequest):
                        result = await self.task_manager.on_resubscribe_to_task(
                            json_rpc_request
                        )
                    else:
                        logger.warning(
                            f"[SERVER] Request type not supported: {type(json_rpc_request)}"
                        )
                        return JSONResponse(
                            status_code=400,
                            content={
                                "jsonrpc": "2.0",
                                "id": body.get("id"),
                                "error": {
                                    "code": -32601,
                                    "message": "Method not found",
                                    "data": {
                                        "detail": f"Method not supported: {method}"
                                    },
                                },
                            },
                        )
                finally:
                    # Restore the original db
                    self.task_manager.db = original_db

                # Create appropriate response
                return self._create_response(result)

            except json.JSONDecodeError as e:
                # Error parsing JSON
                logger.error(f"Error parsing JSON request: {str(e)}")
                return JSONResponse(
                    status_code=400,
                    content={
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {
                            "code": -32700,
                            "message": "Parse error",
                            "data": {"detail": str(e)},
                        },
                    },
                )
            except Exception as e:
                # Other validation errors
                logger.error(f"Error validating request: {str(e)}")
                return JSONResponse(
                    status_code=400,
                    content={
                        "jsonrpc": "2.0",
                        "id": body.get("id") if "body" in locals() else None,
                        "error": {
                            "code": -32600,
                            "message": "Invalid Request",
                            "data": {"detail": str(e)},
                        },
                    },
                )

        except Exception as e:
            logger.error(f"Error processing JSON-RPC request: {str(e)}", exc_info=True)
            return JSONResponse(
                status_code=500,
                content={
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32603,
                        "message": "Internal error",
                        "data": {"detail": str(e)},
                    },
                },
            )

    def _create_response(self, result: Any) -> Union[JSONResponse, StreamingResponse]:
        """
        Create appropriate response based on result type.

        Args:
            result: Result from task manager

        Returns:
            JSON or streaming response
        """
        if isinstance(result, AsyncIterable):
            # Result in streaming (SSE)
            async def event_generator():
                async for item in result:
                    if hasattr(item, "model_dump_json"):
                        yield {"data": item.model_dump_json(exclude_none=True)}
                    else:
                        yield {"data": json.dumps(item)}

            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        elif hasattr(result, "model_dump"):
            # Result is a Pydantic object
            return JSONResponse(result.model_dump(exclude_none=True))

        else:
            # Result is a dictionary or other simple type
            return JSONResponse(result)

    async def get_agent_card(
        self, request: Request, db: Optional[Session] = None
    ) -> JSONResponse:
        """
        Get the agent card.

        Args:
            request: HTTP request
            db: Database session

        Returns:
            Agent card as JSON
        """
        if not self.agent_card:
            logger.error("Agent card not configured")
            return JSONResponse(
                status_code=404, content={"error": "Agent card not configured"}
            )

        # If there is a db session, set it temporarily on the task_manager
        if db and hasattr(self.task_manager, "db"):
            original_db = self.task_manager.db
            try:
                self.task_manager.db = db

                # If it's a Pydantic object, convert to dictionary
                if hasattr(self.agent_card, "model_dump"):
                    return JSONResponse(self.agent_card.model_dump(exclude_none=True))
                else:
                    return JSONResponse(self.agent_card)
            finally:
                # Restore the original db
                self.task_manager.db = original_db
        else:
            # If it's a Pydantic object, convert to dictionary
            if hasattr(self.agent_card, "model_dump"):
                return JSONResponse(self.agent_card.model_dump(exclude_none=True))
            else:
                return JSONResponse(self.agent_card)
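For reference, the numeric codes in the error envelopes above are the standard JSON-RPC 2.0 error codes: -32700 (parse error), -32600 (invalid request), -32601 (method not found), and -32603 (internal error). A small helper that builds the same envelope, shown here only as a sketch:

def json_rpc_error(code: int, message: str, detail: str, request_id=None) -> dict:
    # Standard JSON-RPC 2.0 error envelope, matching the responses above.
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message, "data": {"detail": detail}},
    }


# Example: the body returned for an unsupported method
print(json_rpc_error(-32601, "Method not found", "Method not supported: tasks/unknown"))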
644 src/services/a2a_task_manager.py Normal file
@ -0,0 +1,644 @@
import json
import logging
import asyncio
from collections.abc import AsyncIterable
from typing import Any, Dict, Optional, Union, List
from uuid import UUID

from sqlalchemy.orm import Session

from src.config.settings import settings
from src.services.agent_service import (
    get_agent,
    create_agent,
    update_agent,
    delete_agent,
    get_agents_by_client,
)
from src.services.mcp_server_service import get_mcp_server
from src.services.session_service import (
    get_sessions_by_client,
    get_sessions_by_agent,
    get_session_by_id,
    delete_session,
    get_session_events,
)

from src.services.agent_runner import run_agent, run_agent_stream
from src.services.service_providers import (
    session_service,
    artifacts_service,
    memory_service,
)
from src.models.models import Agent
from src.schemas.a2a_types import (
    A2ARequest,
    GetTaskRequest,
    SendTaskRequest,
    SendTaskResponse,
    SendTaskStreamingRequest,
    SendTaskStreamingResponse,
    CancelTaskRequest,
    SetTaskPushNotificationRequest,
    GetTaskPushNotificationRequest,
    TaskResubscriptionRequest,
    JSONRPCResponse,
    TaskStatusUpdateEvent,
    TaskArtifactUpdateEvent,
    Task,
    TaskSendParams,
    InternalError,
    Message,
    Artifact,
    TaskStatus,
    TaskState,
    AgentCard,
    AgentCapabilities,
    AgentSkill,
)

logger = logging.getLogger(__name__)


class A2ATaskManager:
    """Task manager for the A2A protocol."""

    def __init__(self, db: Session):
        self.db = db
        self.tasks: Dict[str, Task] = {}
        self.lock = asyncio.Lock()

    async def upsert_task(self, task_params: TaskSendParams) -> Task:
        """Creates or updates a task in the store."""
        async with self.lock:
            task = self.tasks.get(task_params.id)
            if task is None:
                # Create new task with initial history
                task = Task(
                    id=task_params.id,
                    sessionId=task_params.sessionId,
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_params.message],
                    artifacts=[],
                )
                self.tasks[task_params.id] = task
            else:
                # Add message to existing history
                if task.history is None:
                    task.history = []
                task.history.append(task_params.message)

            return task

    async def on_get_task(self, request: GetTaskRequest) -> JSONRPCResponse:
        """Handles requests to get task details."""
        try:
            task_id = request.params.id
            history_length = request.params.historyLength

            async with self.lock:
                if task_id not in self.tasks:
                    return JSONRPCResponse(
                        id=request.id,
                        error=InternalError(message=f"Task {task_id} not found"),
                    )

                # Get the task and limit the history as requested
                task = self.tasks[task_id]
                task_result = self.append_task_history(task, history_length)

            return SendTaskResponse(id=request.id, result=task_result)
        except Exception as e:
            logger.error(f"Error getting task: {e}")
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Error getting task: {str(e)}"),
            )

    async def on_send_task(
        self, request: SendTaskRequest, agent_id: UUID
    ) -> JSONRPCResponse:
        """Handles requests to send a task for processing."""
        try:
            agent = get_agent(self.db, agent_id)
            if not agent:
                return JSONRPCResponse(
                    id=request.id,
                    error=InternalError(message=f"Agent {agent_id} not found"),
                )

            await self.upsert_task(request.params)
            return await self._process_task(request, agent)
        except Exception as e:
            logger.error(f"Error sending task: {e}")
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Error sending task: {str(e)}"),
            )

    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest, agent_id: UUID
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Handles requests to send a task and subscribe to updates."""
        try:
            agent = get_agent(self.db, agent_id)
            if not agent:
                yield JSONRPCResponse(
                    id=request.id,
                    error=InternalError(message=f"Agent {agent_id} not found"),
                )
                return

            await self.upsert_task(request.params)
            async for response in self._stream_task_process(request, agent):
                yield response
        except Exception as e:
            logger.error(f"Error processing streaming task: {e}")
            yield JSONRPCResponse(
                id=request.id,
                error=InternalError(
                    message=f"Error processing streaming task: {str(e)}"
                ),
            )

    async def on_cancel_task(self, request: CancelTaskRequest) -> JSONRPCResponse:
        """Handles requests to cancel a task."""
        try:
            task_id = request.params.id
            async with self.lock:
                if task_id in self.tasks:
                    task = self.tasks[task_id]
                    task.status = TaskStatus(state=TaskState.CANCELED)
                    return JSONRPCResponse(id=request.id, result=True)
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Task {task_id} not found"),
            )
        except Exception as e:
            logger.error(f"Error canceling task: {e}")
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Error canceling task: {str(e)}"),
            )

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> JSONRPCResponse:
        """Handles requests to configure push notifications for a task."""
        return JSONRPCResponse(id=request.id, result=True)

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> JSONRPCResponse:
        """Handles requests to get push notification settings for a task."""
        return JSONRPCResponse(id=request.id, result={})

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Handles requests to resubscribe to a task."""
        task_id = request.params.id
        try:
            async with self.lock:
                if task_id not in self.tasks:
                    yield SendTaskStreamingResponse(
                        id=request.id,
                        error=InternalError(message=f"Task {task_id} not found"),
                    )
                    return

                task = self.tasks[task_id]

            yield SendTaskStreamingResponse(
                id=request.id,
                result=TaskStatusUpdateEvent(
                    id=task_id,
                    status=task.status,
                    final=False,
                ),
            )

            if task.artifacts:
                for artifact in task.artifacts:
                    yield SendTaskStreamingResponse(
                        id=request.id,
                        result=TaskArtifactUpdateEvent(id=task_id, artifact=artifact),
                    )

            yield SendTaskStreamingResponse(
                id=request.id,
                result=TaskStatusUpdateEvent(
                    id=task_id,
                    status=TaskStatus(state=task.status.state),
                    final=True,
                ),
            )

        except Exception as e:
            logger.error(f"Error resubscribing to task: {e}")
            yield SendTaskStreamingResponse(
                id=request.id,
                error=InternalError(message=f"Error resubscribing to task: {str(e)}"),
            )

    async def _process_task(
        self, request: SendTaskRequest, agent: Agent
    ) -> JSONRPCResponse:
        """Processes a task using the specified agent."""
        task_params = request.params
        query = self._extract_user_query(task_params)

        try:
            # Process the query with the agent
            result = await self._run_agent(agent, query, task_params.sessionId)

            # Create the response part
            text_part = {"type": "text", "text": result}
            parts = [text_part]
            agent_message = Message(role="agent", parts=parts)

            # Determine the task state
            task_state = (
                TaskState.INPUT_REQUIRED
                if "MISSING_INFO:" in result
                else TaskState.COMPLETED
            )

            # Update the task in the store
            task = await self.update_store(
                task_params.id,
                TaskStatus(state=task_state, message=agent_message),
                [Artifact(parts=parts, index=0)],
            )

            return SendTaskResponse(id=request.id, result=task)
        except Exception as e:
            logger.error(f"Error processing task: {e}")
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Error processing task: {str(e)}"),
            )

    async def _stream_task_process(
        self, request: SendTaskStreamingRequest, agent: Agent
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Processes a task in streaming mode using the specified agent."""
        task_params = request.params
        query = self._extract_user_query(task_params)

        try:
            # Send initial processing status
            processing_text_part = {
                "type": "text",
                "text": "Processing your request...",
            }
            processing_message = Message(role="agent", parts=[processing_text_part])

            # Update the task with the processing message and report the WORKING state
            await self.update_store(
                task_params.id,
                TaskStatus(state=TaskState.WORKING, message=processing_message),
            )

            yield SendTaskStreamingResponse(
                id=request.id,
                result=TaskStatusUpdateEvent(
                    id=task_params.id,
                    status=TaskStatus(
                        state=TaskState.WORKING,
                        message=processing_message,
                    ),
                    final=False,
                ),
            )

            # Collect the chunks of the agent's response
            contact_id = task_params.sessionId
            full_response = ""

            # We use the same streaming function used in the WebSocket
            async for chunk in run_agent_stream(
                agent_id=str(agent.id),
                contact_id=contact_id,
                message=query,
                session_service=session_service,
                artifacts_service=artifacts_service,
                memory_service=memory_service,
                db=self.db,
            ):
                # Send incremental progress updates
                update_text_part = {"type": "text", "text": chunk}
                update_message = Message(role="agent", parts=[update_text_part])

                # Update the task with each intermediate message
                await self.update_store(
                    task_params.id,
                    TaskStatus(state=TaskState.WORKING, message=update_message),
                )

                yield SendTaskStreamingResponse(
                    id=request.id,
                    result=TaskStatusUpdateEvent(
                        id=task_params.id,
                        status=TaskStatus(
                            state=TaskState.WORKING,
                            message=update_message,
                        ),
                        final=False,
                    ),
                )
                full_response += chunk

            # Determine the task state
            task_state = (
                TaskState.INPUT_REQUIRED
                if "MISSING_INFO:" in full_response
                else TaskState.COMPLETED
            )

            # Create the final response part
            final_text_part = {"type": "text", "text": full_response}
            parts = [final_text_part]
            final_message = Message(role="agent", parts=parts)

            # Create the final artifact from the final response
            final_artifact = Artifact(parts=parts, index=0)

            # Update the task in the store with the final response
            task = await self.update_store(
                task_params.id,
                TaskStatus(state=task_state, message=final_message),
                [final_artifact],
            )

            # Send the final artifact
            yield SendTaskStreamingResponse(
                id=request.id,
                result=TaskArtifactUpdateEvent(
                    id=task_params.id, artifact=final_artifact
                ),
            )

            # Send the final status
            yield SendTaskStreamingResponse(
                id=request.id,
                result=TaskStatusUpdateEvent(
                    id=task_params.id,
                    status=TaskStatus(state=task_state),
                    final=True,
                ),
            )
        except Exception as e:
            logger.error(f"Error streaming task process: {e}")
            yield JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Error streaming task process: {str(e)}"),
            )

    async def update_store(
        self,
        task_id: str,
        status: TaskStatus,
        artifacts: Optional[list[Artifact]] = None,
    ) -> Task:
        """Updates the status and artifacts of a task."""
        async with self.lock:
            if task_id not in self.tasks:
                raise ValueError(f"Task {task_id} not found")

            task = self.tasks[task_id]
            task.status = status

            # Add message to history if it exists
            if status.message is not None:
                if task.history is None:
                    task.history = []
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def _extract_user_query(self, task_params: TaskSendParams) -> str:
        """Extracts the user query from the task parameters."""
        if not task_params.message or not task_params.message.parts:
            raise ValueError("Message or parts are missing in task parameters")

        part = task_params.message.parts[0]
        if part.type != "text":
            raise ValueError("Only text parts are supported")

        return part.text

    async def _run_agent(self, agent: Agent, query: str, session_id: str) -> str:
        """Executes the agent to process the user query."""
        try:
            # We use the session_id as contact_id to maintain conversation continuity
            contact_id = session_id

            # We call the same function used in the chat API
            final_response = await run_agent(
                agent_id=str(agent.id),
                contact_id=contact_id,
                message=query,
                session_service=session_service,
                artifacts_service=artifacts_service,
                memory_service=memory_service,
                db=self.db,
            )

            return final_response
        except Exception as e:
            logger.error(f"Error running agent: {e}")
            raise ValueError(f"Error running agent: {str(e)}")

    def append_task_history(self, task: Task, history_length: int | None) -> Task:
        """Returns a copy of the task with the history limited to the specified size."""
        # Create a copy of the task
        new_task = task.model_copy()

        # Limit the history if requested
        if history_length is not None:
            if history_length > 0:
                new_task.history = (
                    new_task.history[-history_length:] if new_task.history else []
                )
            else:
                new_task.history = []

        return new_task


class A2AService:
    """Service to manage A2A requests and agent cards."""

    def __init__(self, db: Session, task_manager: A2ATaskManager):
        self.db = db
        self.task_manager = task_manager

    async def process_request(
        self, agent_id: UUID, request_body: dict
    ) -> JSONRPCResponse:
        """Processes an A2A request."""
        try:
            request = A2ARequest.validate_python(request_body)

            if isinstance(request, GetTaskRequest):
                return await self.task_manager.on_get_task(request)
            elif isinstance(request, SendTaskRequest):
                return await self.task_manager.on_send_task(request, agent_id)
            elif isinstance(request, SendTaskStreamingRequest):
                return self.task_manager.on_send_task_subscribe(request, agent_id)
            elif isinstance(request, CancelTaskRequest):
                return await self.task_manager.on_cancel_task(request)
            elif isinstance(request, SetTaskPushNotificationRequest):
                return await self.task_manager.on_set_task_push_notification(request)
            elif isinstance(request, GetTaskPushNotificationRequest):
                return await self.task_manager.on_get_task_push_notification(request)
            elif isinstance(request, TaskResubscriptionRequest):
                return self.task_manager.on_resubscribe_to_task(request)
            else:
                logger.warning(f"Unexpected request type: {type(request)}")
                return JSONRPCResponse(
                    id=getattr(request, "id", None),
                    error=InternalError(
                        message=f"Unexpected request type: {type(request)}"
                    ),
                )
        except Exception as e:
            logger.error(f"Error processing A2A request: {e}")
            return JSONRPCResponse(
                id=None,
                error=InternalError(message=f"Error processing A2A request: {str(e)}"),
            )

    def get_agent_card(self, agent_id: UUID) -> AgentCard:
        """Gets the agent card for the specified agent."""
        agent = get_agent(self.db, agent_id)
        if not agent:
            raise ValueError(f"Agent {agent_id} not found")

        # Build the agent card based on the agent's information
        capabilities = AgentCapabilities(streaming=True)

        # List to store all skills
        skills = []

        # Check if the agent has MCP servers configured
        if (
            agent.config
            and "mcp_servers" in agent.config
            and agent.config["mcp_servers"]
        ):
            logger.info(
                f"Agent {agent_id} has {len(agent.config['mcp_servers'])} MCP servers configured"
            )

            for mcp_config in agent.config["mcp_servers"]:
                # Get the MCP server
                mcp_server_id = mcp_config.get("id")
                if not mcp_server_id:
                    logger.warning("MCP server configuration missing ID")
                    continue

                logger.info(f"Processing MCP server: {mcp_server_id}")
                mcp_server = get_mcp_server(self.db, mcp_server_id)
                if not mcp_server:
                    logger.warning(f"MCP server {mcp_server_id} not found")
                    continue

                # Get the available tools in the MCP server
                mcp_tools = mcp_config.get("tools", [])
                logger.info(f"MCP server {mcp_server.name} has tools: {mcp_tools}")

                # Add server tools as skills
                for tool_name in mcp_tools:
                    logger.info(f"Processing tool: {tool_name}")

                    # Look up the tool information by ID
                    tool_info = None
                    if hasattr(mcp_server, "tools") and isinstance(
                        mcp_server.tools, list
                    ):
                        for tool in mcp_server.tools:
                            if isinstance(tool, dict) and tool.get("id") == tool_name:
                                tool_info = tool
                                logger.info(
                                    f"Found tool info for {tool_name}: {tool_info}"
                                )
                                break

                    if tool_info:
                        # Use the information from the tool
                        skill = AgentSkill(
                            id=tool_info.get("id", f"{agent.id}_{tool_name}"),
                            name=tool_info.get("name", tool_name),
                            description=tool_info.get(
                                "description", f"Tool: {tool_name}"
                            ),
                            tags=tool_info.get(
                                "tags", [mcp_server.name, "tool", tool_name]
                            ),
                            examples=tool_info.get("examples", []),
                            inputModes=tool_info.get("inputModes", ["text"]),
                            outputModes=tool_info.get("outputModes", ["text"]),
                        )
                    else:
                        # Default skill if tool info not found
                        skill = AgentSkill(
                            id=f"{agent.id}_{tool_name}",
                            name=tool_name,
                            description=f"Tool: {tool_name}",
                            tags=[mcp_server.name, "tool", tool_name],
                            examples=[],
                            inputModes=["text"],
                            outputModes=["text"],
                        )

                    skills.append(skill)
                    logger.info(f"Added skill for tool: {tool_name}")

        # Check custom tools
        if (
            agent.config
            and "custom_tools" in agent.config
            and agent.config["custom_tools"]
        ):
            custom_tools = agent.config["custom_tools"]

            # Check HTTP tools
            if "http_tools" in custom_tools and custom_tools["http_tools"]:
                logger.info(f"Agent has {len(custom_tools['http_tools'])} HTTP tools")
                for http_tool in custom_tools["http_tools"]:
                    skill = AgentSkill(
                        id=f"{agent.id}_http_{http_tool['name']}",
                        name=http_tool["name"],
                        description=http_tool.get(
                            "description", f"HTTP Tool: {http_tool['name']}"
                        ),
                        tags=http_tool.get(
                            "tags", ["http", "custom_tool", http_tool["method"]]
                        ),
                        examples=http_tool.get("examples", []),
                        inputModes=http_tool.get("inputModes", ["text"]),
                        outputModes=http_tool.get("outputModes", ["text"]),
                    )
                    skills.append(skill)
                    logger.info(f"Added skill for HTTP tool: {http_tool['name']}")

        card = AgentCard(
            name=agent.name,
            description=agent.description or "",
            url=f"{settings.API_URL}/api/v1/a2a/{agent_id}",
            version="1.0.0",
            defaultInputModes=["text"],
            defaultOutputModes=["text"],
            capabilities=capabilities,
            skills=skills,
        )

        logger.info(f"Generated agent card with {len(skills)} skills")
        return card
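A minimal sketch of driving the new task manager end to end, e.g. from a script or test. The SessionLocal factory and the agent UUID are assumptions for illustration; the "tasks/send" method name follows the A2A request schema used above:

import asyncio
import uuid

from src.services.a2a_task_manager import A2AService, A2ATaskManager
from src.config.database import SessionLocal  # assumed session factory


async def send_hello(agent_id: uuid.UUID) -> None:
    db = SessionLocal()
    try:
        service = A2AService(db, A2ATaskManager(db))
        response = await service.process_request(
            agent_id,
            {
                "jsonrpc": "2.0",
                "id": "1",
                "method": "tasks/send",  # parsed by A2ARequest into SendTaskRequest
                "params": {
                    "id": str(uuid.uuid4()),
                    "sessionId": str(uuid.uuid4()),
                    "message": {
                        "role": "user",
                        "parts": [{"type": "text", "text": "hello"}],
                    },
                },
            },
        )
        print(response.model_dump_json(exclude_none=True))
    finally:
        db.close()


# asyncio.run(send_hello(uuid.UUID("<agent-uuid>")))  # placeholder agent id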
src/services/a2a_task_manager_service.py
@ -1,949 +0,0 @@
|
||||
"""
|
||||
A2A Task Manager Service.
|
||||
|
||||
This service implements task management for the A2A protocol, handling task lifecycle
|
||||
including execution, streaming, push notifications, status queries, and cancellation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Union, AsyncIterable
|
||||
import uuid
|
||||
|
||||
from src.schemas.a2a.exceptions import (
|
||||
TaskNotFoundError,
|
||||
TaskNotCancelableError,
|
||||
PushNotificationNotSupportedError,
|
||||
InternalError,
|
||||
ContentTypeNotSupportedError,
|
||||
)
|
||||
|
||||
from src.schemas.a2a.types import (
|
||||
JSONRPCResponse,
|
||||
GetTaskRequest,
|
||||
SendTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskResponse,
|
||||
CancelTaskResponse,
|
||||
SendTaskResponse,
|
||||
SetTaskPushNotificationResponse,
|
||||
GetTaskPushNotificationResponse,
|
||||
TaskSendParams,
|
||||
TaskStatus,
|
||||
TaskState,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
Artifact,
|
||||
PushNotificationConfig,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
JSONRPCError,
|
||||
TaskPushNotificationConfig,
|
||||
Message,
|
||||
TextPart,
|
||||
Task,
|
||||
)
|
||||
from src.services.redis_cache_service import RedisCacheService
|
||||
from src.utils.a2a_utils import (
|
||||
are_modalities_compatible,
|
||||
new_incompatible_types_error,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2ATaskManager:
|
||||
"""
|
||||
A2A Task Manager implementation.
|
||||
|
||||
This class manages the lifecycle of A2A tasks, including:
|
||||
- Task submission and execution
|
||||
- Task status queries
|
||||
- Task cancellation
|
||||
- Push notification configuration
|
||||
- SSE streaming for real-time updates
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_cache: RedisCacheService,
|
||||
agent_runner=None,
|
||||
streaming_service=None,
|
||||
push_notification_service=None,
|
||||
db=None,
|
||||
):
|
||||
"""
|
||||
Initialize the A2A Task Manager.
|
||||
|
||||
Args:
|
||||
redis_cache: Redis cache service for task storage
|
||||
agent_runner: Agent runner service for task execution
|
||||
streaming_service: Streaming service for SSE
|
||||
push_notification_service: Service for push notifications
|
||||
db: Database session
|
||||
"""
|
||||
self.redis_cache = redis_cache
|
||||
self.agent_runner = agent_runner
|
||||
self.streaming_service = streaming_service
|
||||
self.push_notification_service = push_notification_service
|
||||
self.db = db
|
||||
self.lock = asyncio.Lock()
|
||||
self.subscriber_lock = asyncio.Lock()
|
||||
self.task_sse_subscribers = {}
|
||||
|
||||
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
|
||||
"""
|
||||
Handle request to get task information.
|
||||
|
||||
Args:
|
||||
request: A2A Get Task request
|
||||
|
||||
Returns:
|
||||
Response with task details
|
||||
"""
|
||||
try:
|
||||
task_id = request.params.id
|
||||
history_length = request.params.historyLength
|
||||
|
||||
# Get task data from cache
|
||||
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||
|
||||
if not task_data:
|
||||
logger.warning(f"Task not found: {task_id}")
|
||||
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Create a Task instance from cache data
|
||||
task = Task.model_validate(task_data)
|
||||
|
||||
# If historyLength parameter is present, handle the history
|
||||
if history_length is not None and task.history:
|
||||
if history_length == 0:
|
||||
task.history = []
|
||||
elif len(task.history) > history_length:
|
||||
task.history = task.history[-history_length:]
|
||||
|
||||
return GetTaskResponse(id=request.id, result=task)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing on_get_task: {str(e)}")
|
||||
return GetTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
||||
|
||||
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
|
||||
"""
|
||||
Handle request to cancel a running task.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to cancel a task
|
||||
|
||||
Returns:
|
||||
Response with updated task data or error
|
||||
"""
|
||||
logger.info(f"Cancelling task {request.params.id}")
|
||||
task_id_params = request.params
|
||||
|
||||
try:
|
||||
task_data = await self.redis_cache.get(f"task:{task_id_params.id}")
|
||||
if not task_data:
|
||||
logger.warning(f"Task {task_id_params.id} not found for cancellation")
|
||||
return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Check if task can be cancelled
|
||||
current_state = task_data.get("status", {}).get("state")
|
||||
if current_state not in [TaskState.SUBMITTED, TaskState.WORKING]:
|
||||
logger.warning(
|
||||
f"Task {task_id_params.id} in state {current_state} cannot be cancelled"
|
||||
)
|
||||
return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
|
||||
|
||||
# Update task status to cancelled
|
||||
task_data["status"] = {
|
||||
"state": TaskState.CANCELED,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": "Task cancelled by user"}],
|
||||
},
|
||||
}
|
||||
|
||||
# Save updated task data
|
||||
await self.redis_cache.set(f"task:{task_id_params.id}", task_data)
|
||||
|
||||
# Send push notification if configured
|
||||
await self._send_push_notification_for_task(
|
||||
task_id_params.id, "canceled", system_message="Task cancelled by user"
|
||||
)
|
||||
|
||||
# Publish event to SSE subscribers
|
||||
await self._publish_task_update(
|
||||
task_id_params.id,
|
||||
TaskStatusUpdateEvent(
|
||||
id=task_id_params.id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.CANCELED,
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text="Task cancelled by user")],
|
||||
),
|
||||
),
|
||||
final=True,
|
||||
),
|
||||
)
|
||||
|
||||
return CancelTaskResponse(id=request.id, result=task_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling task: {str(e)}", exc_info=True)
|
||||
return CancelTaskResponse(
|
||||
id=request.id,
|
||||
error=InternalError(message=f"Error cancelling task: {str(e)}"),
|
||||
)
|
||||
|
||||
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
|
||||
"""
|
||||
Handle request to send a new task.
|
||||
|
||||
Args:
|
||||
request: Send Task request
|
||||
|
||||
Returns:
|
||||
Response with the created task details
|
||||
"""
|
||||
try:
|
||||
params = request.params
|
||||
task_id = params.id
|
||||
logger.info(f"Receiving task {task_id}")
|
||||
|
||||
# Check if a task with this ID already exists
|
||||
existing_task = await self.redis_cache.get(f"task:{task_id}")
|
||||
if existing_task:
|
||||
# If the task already exists and is in progress, return the current task
|
||||
if existing_task.get("status", {}).get("state") in [
|
||||
TaskState.WORKING,
|
||||
TaskState.COMPLETED,
|
||||
]:
|
||||
return SendTaskResponse(
|
||||
id=request.id, result=Task.model_validate(existing_task)
|
||||
)
|
||||
|
||||
# If the task exists but failed or was canceled, we can reprocess it
|
||||
logger.info(f"Reprocessing existing task {task_id}")
|
||||
|
||||
# Check modality compatibility
|
||||
server_output_modes = []
|
||||
if self.agent_runner:
|
||||
# Try to get supported modes from the agent
|
||||
try:
|
||||
server_output_modes = await self.agent_runner.get_supported_modes()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting supported modes: {str(e)}")
|
||||
server_output_modes = ["text"] # Fallback to text
|
||||
|
||||
if not are_modalities_compatible(
|
||||
server_output_modes, params.acceptedOutputModes
|
||||
):
|
||||
logger.warning(
|
||||
f"Incompatible modes: server={server_output_modes}, client={params.acceptedOutputModes}"
|
||||
)
|
||||
return SendTaskResponse(
|
||||
id=request.id, error=ContentTypeNotSupportedError()
|
||||
)
|
||||
|
||||
# Create task data
|
||||
task_data = await self._create_task_data(params)
|
||||
|
||||
# Store task in cache
|
||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Configure push notifications, if provided
|
||||
if params.pushNotification:
|
||||
await self.redis_cache.set(
|
||||
f"task_notification:{task_id}", params.pushNotification.model_dump()
|
||||
)
|
||||
|
||||
# Execute task SYNCHRONOUSLY instead of in background
|
||||
# This is the key change for A2A compatibility
|
||||
task_data = await self._execute_task(task_data, params)
|
||||
|
||||
# Convert to Task object and return
|
||||
task = Task.model_validate(task_data)
|
||||
return SendTaskResponse(id=request.id, result=task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing on_send_task: {str(e)}")
|
||||
return SendTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
||||
|
||||
async def on_send_task_subscribe(
|
||||
self, request: SendTaskStreamingRequest
|
||||
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||
"""
|
||||
Handle request to send a task and subscribe to streaming updates.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to send a task with streaming
|
||||
|
||||
Returns:
|
||||
Stream of events or error response
|
||||
"""
|
||||
logger.info(f"Sending task with streaming {request.params.id}")
|
||||
task_send_params = request.params
|
||||
|
||||
try:
|
||||
# Check output mode compatibility
|
||||
if not are_modalities_compatible(
|
||||
["text", "application/json"], # Default supported modes
|
||||
task_send_params.acceptedOutputModes,
|
||||
):
|
||||
return new_incompatible_types_error(request.id)
|
||||
|
||||
# Create initial task data
|
||||
task_data = await self._create_task_data(task_send_params)
|
||||
|
||||
# Setup SSE consumer
|
||||
sse_queue = await self._setup_sse_consumer(task_send_params.id)
|
||||
|
||||
# Execute task asynchronously (fire and forget)
|
||||
asyncio.create_task(self._execute_task(task_data, task_send_params))
|
||||
|
||||
# Return generator to dequeue events for SSE
|
||||
return self._dequeue_events_for_sse(
|
||||
request.id, task_send_params.id, sse_queue
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True)
|
||||
return SendTaskStreamingResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message=f"Error setting up streaming task: {str(e)}"
|
||||
),
|
||||
)
|
||||
|
||||
async def on_set_task_push_notification(
|
||||
self, request: SetTaskPushNotificationRequest
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
"""
|
||||
Configure push notifications for a task.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to set push notification
|
||||
|
||||
Returns:
|
||||
Response with configuration or error
|
||||
"""
|
||||
logger.info(f"Setting push notification for task {request.params.id}")
|
||||
task_notification_params = request.params
|
||||
|
||||
try:
|
||||
if not self.push_notification_service:
|
||||
logger.warning("Push notifications not supported")
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id, error=PushNotificationNotSupportedError()
|
||||
)
|
||||
|
||||
# Check if task exists
|
||||
task_data = await self.redis_cache.get(
|
||||
f"task:{task_notification_params.id}"
|
||||
)
|
||||
if not task_data:
|
||||
logger.warning(
|
||||
f"Task {task_notification_params.id} not found for setting push notification"
|
||||
)
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id, error=TaskNotFoundError()
|
||||
)
|
||||
|
||||
# Save push notification config
|
||||
config = {
|
||||
"url": task_notification_params.pushNotificationConfig.url,
|
||||
"headers": {}, # Add auth headers if needed
|
||||
}
|
||||
|
||||
await self.redis_cache.set(
|
||||
f"task:{task_notification_params.id}:push", config
|
||||
)
|
||||
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id, result=task_notification_params
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting push notification: {str(e)}", exc_info=True)
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message=f"Error setting push notification: {str(e)}"
|
||||
),
|
||||
)
|
||||
|
||||
async def on_get_task_push_notification(
|
||||
self, request: GetTaskPushNotificationRequest
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
"""
|
||||
Get push notification configuration for a task.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to get push notification config
|
||||
|
||||
Returns:
|
||||
Response with configuration or error
|
||||
"""
|
||||
logger.info(f"Getting push notification for task {request.params.id}")
|
||||
task_params = request.params
|
||||
|
||||
try:
|
||||
if not self.push_notification_service:
|
||||
logger.warning("Push notifications not supported")
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id, error=PushNotificationNotSupportedError()
|
||||
)
|
||||
|
||||
# Check if task exists
|
||||
task_data = await self.redis_cache.get(f"task:{task_params.id}")
|
||||
if not task_data:
|
||||
logger.warning(
|
||||
f"Task {task_params.id} not found for getting push notification"
|
||||
)
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id, error=TaskNotFoundError()
|
||||
)
|
||||
|
||||
# Get push notification config
|
||||
config = await self.redis_cache.get(f"task:{task_params.id}:push")
|
||||
if not config:
|
||||
logger.warning(f"No push notification config for task {task_params.id}")
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message="No push notification configuration found"
|
||||
),
|
||||
)
|
||||
|
||||
result = TaskPushNotificationConfig(
|
||||
id=task_params.id,
|
||||
pushNotificationConfig=PushNotificationConfig(
|
||||
url=config.get("url"), token=None, authentication=None
|
||||
),
|
||||
)
|
||||
|
||||
return GetTaskPushNotificationResponse(id=request.id, result=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting push notification: {str(e)}", exc_info=True)
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message=f"Error getting push notification: {str(e)}"
|
||||
),
|
||||
)
|
||||
|
||||
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """
        Resubscribe to a task's streaming events.

        Args:
            request: The JSON-RPC request to resubscribe

        Returns:
            Stream of events or error response
        """
        logger.info(f"Resubscribing to task {request.params.id}")
        task_params = request.params

        try:
            # Check if the task exists
            task_data = await self.redis_cache.get(f"task:{task_params.id}")
            if not task_data:
                logger.warning(f"Task {task_params.id} not found for resubscription")
                return JSONRPCResponse(id=request.id, error=TaskNotFoundError())

            # Set up the SSE consumer with the resubscribe flag
            try:
                sse_queue = await self._setup_sse_consumer(
                    task_params.id, is_resubscribe=True
                )
            except ValueError:
                logger.warning(
                    f"Task {task_params.id} not available for resubscription"
                )
                return JSONRPCResponse(
                    id=request.id,
                    error=InternalError(
                        message="Task not available for resubscription"
                    ),
                )

            # Send an initial status update to the new subscriber
            status = task_data.get("status", {})
            final = status.get("state") in [
                TaskState.COMPLETED,
                TaskState.FAILED,
                TaskState.CANCELED,
            ]

            # Convert to a TaskStatus object
            task_status = TaskStatus(
                state=status.get("state", TaskState.UNKNOWN),
                timestamp=datetime.fromisoformat(
                    status.get("timestamp", datetime.now().isoformat())
                ),
                message=status.get("message"),
            )

            # Publish to the subscriber's queue
            await sse_queue.put(
                TaskStatusUpdateEvent(
                    id=task_params.id, status=task_status, final=final
                )
            )

            # Return a generator that dequeues events for SSE
            return self._dequeue_events_for_sse(request.id, task_params.id, sse_queue)

        except Exception as e:
            logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True)
            return JSONRPCResponse(
                id=request.id,
                error=InternalError(message=f"Error resubscribing to task: {str(e)}"),
            )

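A minimal sketch of how the resubscribe handler might be wired into an HTTP route. The route path, the plain-dict request model, and the task_manager instance are this example's assumptions, not part of the commit.

from fastapi import APIRouter
from sse_starlette.sse import EventSourceResponse

router = APIRouter()


@router.post("/tasks/resubscribe")
async def resubscribe(request: dict):
    # task_manager is assumed to be constructed elsewhere (e.g. via Depends)
    result = await task_manager.on_resubscribe_to_task(request)  # noqa: F821
    if hasattr(result, "__aiter__"):
        # An async generator of SendTaskStreamingResponse objects
        return EventSourceResponse(
            resp.model_dump_json(exclude_none=True) async for resp in result
        )
    return result  # A JSONRPCResponse carrying TaskNotFoundError / InternalError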
    async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]:
        """
        Create the initial task data structure.

        Args:
            params: Task send parameters

        Returns:
            Task data dictionary
        """
        # Create the task with its initial status
        task_data = {
            "id": params.id,
            "sessionId": params.sessionId
            or str(uuid.uuid4()),  # Preserve the sessionId when provided
            "status": {
                "state": TaskState.SUBMITTED,
                "timestamp": datetime.now().isoformat(),
                "message": None,
                "error": None,
            },
            "artifacts": [],
            "history": [params.message.model_dump()],  # User message only
            "metadata": params.metadata or {},
        }

        # Save the task data
        await self.redis_cache.set(f"task:{params.id}", task_data)

        return task_data

    async def _execute_task(
        self, task: Dict[str, Any], params: TaskSendParams
    ) -> Dict[str, Any]:
        """
        Execute a task using the agent adapter.

        This function is responsible for executing the task by the agent,
        updating its status as progress is made.

        Args:
            task: Task data to be executed
            params: Send task parameters

        Returns:
            Updated task data with completed status and response
        """
        task_id = task["id"]
        agent_id = params.agentId
        message_text = ""

        # Extract the text from the message
        if params.message and params.message.parts:
            for part in params.message.parts:
                if part.type == "text":
                    message_text += part.text

        if not message_text:
            await self._update_task_status_without_history(
                task_id, TaskState.FAILED, "Message does not contain text", final=True
            )
            # Return the updated task data
            return await self.redis_cache.get(f"task:{task_id}")

        # Check whether an execution is already in progress
        task_status = task.get("status", {})
        if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
            logger.info(f"Task {task_id} is already in execution or completed")
            # Return the current task data
            return await self.redis_cache.get(f"task:{task_id}")

        try:
            # Update to the "working" state - do NOT add to the history
            await self._update_task_status_without_history(
                task_id, TaskState.WORKING, "Processing request"
            )

            # Execute the agent
            if self.agent_runner:
                response = await self.agent_runner.run_agent(
                    agent_id=agent_id,
                    message=message_text,
                    session_id=params.sessionId,  # Use the sessionId from the request
                    task_id=task_id,
                )

                # Process the agent's response
                if response and isinstance(response, dict):
                    # Extract text from the response
                    response_text = response.get("content", "")
                    if not response_text and "message" in response:
                        message = response.get("message", {})
                        parts = message.get("parts", [])
                        for part in parts:
                            if part.get("type") == "text":
                                response_text += part.get("text", "")

                    # Build the final agent message
                    if response_text:
                        # Update the history with the user message
                        await self._update_task_history(task_id, params.message)

                        # Create an artifact for the response in Google A2A format
                        artifact = {
                            "parts": [{"type": "text", "text": response_text}],
                            "index": 0,
                        }

                        # Add the artifact to the task
                        await self._add_task_artifact(task_id, artifact)

                        # Update the task status to completed (without touching the history)
                        await self._update_task_status_without_history(
                            task_id, TaskState.COMPLETED, response_text, final=True
                        )
                    else:
                        await self._update_task_status_without_history(
                            task_id,
                            TaskState.FAILED,
                            "The agent did not return a valid response",
                            final=True,
                        )
                else:
                    await self._update_task_status_without_history(
                        task_id,
                        TaskState.FAILED,
                        "Invalid agent response",
                        final=True,
                    )
            else:
                await self._update_task_status_without_history(
                    task_id,
                    TaskState.FAILED,
                    "Agent adapter not configured",
                    final=True,
                )
        except Exception as e:
            logger.error(f"Error executing task {task_id}: {str(e)}")
            await self._update_task_status_without_history(
                task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
            )

        # Return the updated task data
        return await self.redis_cache.get(f"task:{task_id}")

    async def _update_task_history(self, task_id: str, message) -> None:
        """
        Update the task history, including only user messages.

        Args:
            task_id: Task ID
            message: User message to add to the history
        """
        if not message:
            return

        # Get the current task data
        task_data = await self.redis_cache.get(f"task:{task_id}")
        if not task_data:
            logger.warning(f"Task {task_id} not found for history update")
            return

        # Make sure there is a history field
        if "history" not in task_data:
            task_data["history"] = []

        # Check whether the message already exists in the history to avoid duplication
        user_message = (
            message.model_dump() if hasattr(message, "model_dump") else message
        )
        message_exists = False

        for msg in task_data["history"]:
            if self._compare_messages(msg, user_message):
                message_exists = True
                break

        # Add the message if it does not exist yet
        if not message_exists:
            task_data["history"].append(user_message)
            logger.info(f"Added new user message to history for task {task_id}")

        # Save the updated task
        await self.redis_cache.set(f"task:{task_id}", task_data)

    # Helper method to compare messages and avoid duplication in the history
    def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool:
        """Compare two messages to check whether they are essentially the same."""
        if msg1.get("role") != msg2.get("role"):
            return False

        parts1 = msg1.get("parts", [])
        parts2 = msg2.get("parts", [])

        if len(parts1) != len(parts2):
            return False

        for i in range(len(parts1)):
            if parts1[i].get("type") != parts2[i].get("type"):
                return False
            if parts1[i].get("text") != parts2[i].get("text"):
                return False

        return True

    async def _update_task_status_without_history(
        self, task_id: str, state: TaskState, message_text: str, final: bool = False
    ) -> None:
        """
        Update the status of a task without changing the history.

        Args:
            task_id: ID of the task to be updated
            state: New task state
            message_text: Text of the message associated with the status
            final: Indicates if this is the final status of the task
        """
        try:
            # Get current task data
            task_data = await self.redis_cache.get(f"task:{task_id}")
            if not task_data:
                logger.warning(f"Unable to update status: task {task_id} not found")
                return

            # Create a status object with the message
            agent_message = Message(
                role="agent",
                parts=[TextPart(text=message_text)],
            )

            status = TaskStatus(
                state=state, message=agent_message, timestamp=datetime.now()
            )

            # Update the status in the task
            task_data["status"] = status.model_dump(exclude_none=True)

            # Store the updated task
            await self.redis_cache.set(f"task:{task_id}", task_data)

            # Create a status update event
            status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final)

            # Publish the status update
            await self._publish_task_update(task_id, status_event)

            # Send a push notification, if configured
            if final or state in [
                TaskState.FAILED,
                TaskState.COMPLETED,
                TaskState.CANCELED,
            ]:
                await self._send_push_notification_for_task(
                    task_id=task_id, state=state, message_text=message_text
                )
        except Exception as e:
            logger.error(f"Error updating task status {task_id}: {str(e)}")

    async def _add_task_artifact(self, task_id: str, artifact) -> None:
        """
        Add an artifact to a task and publish the update.

        Args:
            task_id: Task ID
            artifact: Artifact to add (dict in Google A2A format)
        """
        logger.info(f"Adding artifact to task {task_id}")

        # Update the task data
        task_data = await self.redis_cache.get(f"task:{task_id}")
        if task_data:
            if "artifacts" not in task_data:
                task_data["artifacts"] = []

            # Append the artifact without replacing the existing ones
            task_data["artifacts"].append(artifact)
            await self.redis_cache.set(f"task:{task_id}", task_data)

        # Create an Artifact object for the event
        artifact_obj = Artifact(
            parts=[
                TextPart(text=part.get("text", ""))
                for part in artifact.get("parts", [])
            ],
            index=artifact.get("index", 0),
        )

        # Create the artifact update event
        event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj)

        # Publish the event
        await self._publish_task_update(task_id, event)

    async def _publish_task_update(
        self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]
    ) -> None:
        """
        Publish a task update event to all subscribers.

        Args:
            task_id: Task ID
            event: Event to publish
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return

            subscribers = self.task_sse_subscribers[task_id]
            for subscriber in subscribers:
                try:
                    await subscriber.put(event)
                except Exception as e:
                    logger.error(f"Error publishing event to subscriber: {str(e)}")

    async def _send_push_notification_for_task(
        self,
        task_id: str,
        state: str,
        message_text: str = None,
        system_message: str = None,
    ) -> None:
        """
        Send a push notification for a task, if configured.

        Args:
            task_id: Task ID
            state: Task state
            message_text: Optional message text
            system_message: Optional system message
        """
        if not self.push_notification_service:
            return

        try:
            # Get the push notification config
            config = await self.redis_cache.get(f"task:{task_id}:push")
            if not config:
                return

            # Create a message if one was provided
            message = None
            if message_text:
                message = {
                    "role": "agent",
                    "parts": [{"type": "text", "text": message_text}],
                }
            elif system_message:
                # Use 'agent' instead of 'system' since Message only accepts 'user' or 'agent'
                message = {
                    "role": "agent",
                    "parts": [{"type": "text", "text": system_message}],
                }

            # Send the notification
            await self.push_notification_service.send_notification(
                url=config["url"],
                task_id=task_id,
                state=state,
                message=message,
                headers=config.get("headers", {}),
            )

        except Exception as e:
            logger.error(
                f"Error sending push notification for task {task_id}: {str(e)}"
            )

    async def _setup_sse_consumer(
        self, task_id: str, is_resubscribe: bool = False
    ) -> asyncio.Queue:
        """
        Set up an SSE consumer queue for a task.

        Args:
            task_id: Task ID
            is_resubscribe: Whether this is a resubscription

        Returns:
            Queue for events

        Raises:
            ValueError: If resubscribing to a non-existent task
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []

            queue = asyncio.Queue()
            self.task_sse_subscribers[task_id].append(queue)
            return queue

    async def _dequeue_events_for_sse(
        self, request_id: str, task_id: str, event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """
        Dequeue and yield events for SSE streaming.

        Args:
            request_id: Request ID
            task_id: Task ID
            event_queue: Queue for events

        Yields:
            SSE events wrapped in SendTaskStreamingResponse
        """
        try:
            while True:
                event = await event_queue.get()

                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)

                # Check if this is the final event
                is_final = False
                if hasattr(event, "final") and event.final:
                    is_final = True

                if is_final:
                    break
        finally:
            # Clean up the subscription when done
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    try:
                        self.task_sse_subscribers[task_id].remove(event_queue)
                        # Remove the task from the dict if there are no more subscribers
                        if not self.task_sse_subscribers[task_id]:
                            del self.task_sse_subscribers[task_id]
                    except ValueError:
                        pass  # The queue might have been removed already
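The subscriber bookkeeping used throughout these methods is a small in-process pub/sub: one asyncio.Queue per SSE consumer, fanned out under a lock. Stripped of the A2A types, the pattern reduces to this self-contained sketch:

import asyncio


class Fanout:
    """Minimal per-task fan-out: one queue per subscriber, as used above."""

    def __init__(self):
        self.subscribers: dict[str, list[asyncio.Queue]] = {}
        self.lock = asyncio.Lock()

    async def subscribe(self, task_id: str) -> asyncio.Queue:
        async with self.lock:
            queue = asyncio.Queue()
            self.subscribers.setdefault(task_id, []).append(queue)
            return queue

    async def publish(self, task_id: str, event) -> None:
        async with self.lock:
            for queue in self.subscribers.get(task_id, []):
                await queue.put(event)


async def main():
    fanout = Fanout()
    q1 = await fanout.subscribe("task-1")
    q2 = await fanout.subscribe("task-1")
    await fanout.publish("task-1", {"state": "completed", "final": True})
    print(await q1.get(), await q2.get())  # both subscribers see the event


asyncio.run(main())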
@ -1,265 +0,0 @@
"""
Push Notification Authentication Service.

This service implements JWT authentication for A2A push notifications,
allowing agents to send authenticated notifications and clients to verify
the authenticity of received notifications.
"""

from jwcrypto import jwk
import uuid
import time
import json
import hashlib
import httpx
import logging
import jwt
from jwt import PyJWK, PyJWKClient
from fastapi import Request
from starlette.responses import JSONResponse
from typing import Dict, Any

logger = logging.getLogger(__name__)
AUTH_HEADER_PREFIX = "Bearer "


class PushNotificationAuth:
    """
    Base class for push notification authentication.

    Contains common methods used by both the sender and the receiver
    of push notifications.
    """

    def _calculate_request_body_sha256(self, data: Dict[str, Any]) -> str:
        """
        Calculates the SHA256 hash of the request body.

        This logic needs to be the same for the agent that signs the payload
        and for the client that verifies it.

        Args:
            data: Request body data

        Returns:
            SHA256 hash as a hexadecimal string
        """
        body_str = json.dumps(
            data,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return hashlib.sha256(body_str.encode()).hexdigest()

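Because the hash is computed over one specific serialization (compact separators, no ASCII escaping), the signer and the verifier must serialize identically; the verifier below re-parses the request body, so both sides hash the same structure. A standalone check of that property:

import hashlib
import json


def body_sha256(data: dict) -> str:
    # Same serialization settings as _calculate_request_body_sha256 above
    body_str = json.dumps(
        data, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":")
    )
    return hashlib.sha256(body_str.encode()).hexdigest()


payload = {"taskId": "task-123", "state": "completed"}
wire = json.dumps(payload, separators=(",", ":"))

# Hashing the re-parsed body reproduces the sender's hash
assert body_sha256(json.loads(wire)) == body_sha256(payload)

# But formatting differences break it: json.dumps defaults insert spaces
pretty = json.dumps(payload)
assert hashlib.sha256(pretty.encode()).hexdigest() != body_sha256(payload)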
class PushNotificationSenderAuth(PushNotificationAuth):
    """
    Authentication for the push notification sender.

    This class is used by the A2A server to authenticate notifications
    sent to callback URLs of clients.
    """

    def __init__(self):
        """
        Initialize the push notification sender authentication service.
        """
        self.public_keys = []
        self.private_key_jwk = None

    @staticmethod
    async def verify_push_notification_url(url: str) -> bool:
        """
        Verifies if a push notification URL is valid and responds correctly.

        Sends a validation token and verifies if the response contains the same token.

        Args:
            url: URL to be verified

        Returns:
            True if the URL is verified successfully, False otherwise
        """
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                validation_token = str(uuid.uuid4())
                response = await client.get(
                    url, params={"validationToken": validation_token}
                )
                response.raise_for_status()
                is_verified = response.text == validation_token

                logger.info(f"Push notification URL verified: {url} => {is_verified}")
                return is_verified
            except Exception as e:
                logger.warning(f"Error verifying push notification URL {url}: {e}")

        return False

    def generate_jwk(self):
        """
        Generates a new JWK pair for signing.

        The key pair is used to sign push notifications.
        The public key is available via the JWKS endpoint.
        """
        key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
        self.public_keys.append(key.export_public(as_dict=True))
        self.private_key_jwk = PyJWK.from_json(key.export_private())

    def handle_jwks_endpoint(self, _request: Request) -> JSONResponse:
        """
        Handles the JWKS endpoint to allow clients to obtain the public keys.

        Args:
            _request: HTTP request

        Returns:
            JSON response with the public keys
        """
        return JSONResponse({"keys": self.public_keys})

    def _generate_jwt(self, data: Dict[str, Any]) -> str:
        """
        Generates a JWT token by signing the SHA256 hash of the payload and the timestamp.

        The payload is signed with the private key to ensure integrity.
        The timestamp (iat) prevents replay attacks.

        Args:
            data: Payload data

        Returns:
            Signed JWT token
        """
        iat = int(time.time())

        return jwt.encode(
            {
                "iat": iat,
                "request_body_sha256": self._calculate_request_body_sha256(data),
            },
            key=self.private_key_jwk.key,
            headers={"kid": self.private_key_jwk.key_id},
            algorithm="RS256",
        )

    async def send_push_notification(self, url: str, data: Dict[str, Any]) -> bool:
        """
        Sends an authenticated push notification to the specified URL.

        Args:
            url: URL to send the notification
            data: Notification data

        Returns:
            True if the notification was sent successfully, False otherwise
        """
        if not self.private_key_jwk:
            logger.error(
                "Attempt to send push notification without generating JWK keys"
            )
            return False

        try:
            jwt_token = self._generate_jwt(data)
            headers = {"Authorization": f"Bearer {jwt_token}"}

            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.post(url, json=data, headers=headers)
                response.raise_for_status()
                logger.info(f"Push notification sent to URL: {url}")
                return True
        except Exception as e:
            logger.warning(f"Error sending push notification to URL {url}: {e}")
            return False


class PushNotificationReceiverAuth(PushNotificationAuth):
    """
    Authentication for the push notification receiver.

    This class is used by clients to verify the authenticity
    of notifications received from the A2A server.
    """

    def __init__(self):
        """
        Initialize the push notification receiver authentication service.
        """
        self.jwks_client = None

    async def load_jwks(self, jwks_url: str):
        """
        Loads the public JWKS keys from a URL.

        Args:
            jwks_url: URL of the JWKS endpoint
        """
        self.jwks_client = PyJWKClient(jwks_url)

    async def verify_push_notification(self, request: Request) -> bool:
        """
        Verifies the authenticity of a push notification.

        Args:
            request: HTTP request containing the notification

        Returns:
            True if the notification is authentic, False otherwise

        Raises:
            ValueError: If the token is invalid or expired
        """
        if not self.jwks_client:
            logger.error("Attempt to verify notification without loading JWKS keys")
            return False

        # Verify the authentication header
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
            logger.warning("Invalid authorization header")
            return False

        try:
            # Extract and verify the token
            token = auth_header[len(AUTH_HEADER_PREFIX) :]
            signing_key = self.jwks_client.get_signing_key_from_jwt(token)

            # Decode the token
            decode_token = jwt.decode(
                token,
                signing_key.key,
                options={"require": ["iat", "request_body_sha256"]},
                algorithms=["RS256"],
            )

            # Verify the request body integrity
            body_data = await request.json()
            actual_body_sha256 = self._calculate_request_body_sha256(body_data)
            if actual_body_sha256 != decode_token["request_body_sha256"]:
                # The payload signature does not match the hash in the signed token
                logger.warning("Request body hash does not match the token")
                raise ValueError("Invalid request body")

            # Verify the token age (maximum of 5 minutes)
            if time.time() - decode_token["iat"] > 60 * 5:
                # Do not allow notifications older than 5 minutes;
                # this prevents replay attacks
                logger.warning("Token expired")
                raise ValueError("Token expired")

            return True

        except Exception as e:
            logger.error(f"Error verifying push notification: {e}")
            return False


# Global instance of the push notification sender authentication service
push_notification_auth = PushNotificationSenderAuth()

# Generate JWK keys on initialization
push_notification_auth.generate_jwk()
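On the client side, the receiver class is meant to back whatever webhook endpoint accepts the notifications. A minimal sketch, assuming the agent exposes its JWKS at /.well-known/jwks.json (the path, host, and route names are this example's assumptions):

from fastapi import FastAPI, Request
from starlette.responses import JSONResponse, PlainTextResponse

app = FastAPI()
receiver_auth = PushNotificationReceiverAuth()  # class defined above


@app.on_event("startup")
async def startup():
    # The JWKS URL is an assumption for this sketch
    await receiver_auth.load_jwks("https://agent.example.com/.well-known/jwks.json")


@app.get("/webhook")
async def validate(validationToken: str):
    # Echo the token so verify_push_notification_url() accepts this URL
    return PlainTextResponse(validationToken)


@app.post("/webhook")
async def receive(request: Request):
    if not await receiver_auth.verify_push_notification(request):
        return JSONResponse({"detail": "invalid signature"}, status_code=401)
    payload = await request.json()
    return JSONResponse({"received": payload.get("taskId")})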
@ -1,99 +0,0 @@
import aiohttp
import logging
from datetime import datetime
from typing import Dict, Any, Optional
import asyncio

from src.services.push_notification_auth_service import push_notification_auth

logger = logging.getLogger(__name__)


class PushNotificationService:
    def __init__(self):
        self.session = aiohttp.ClientSession()

    async def send_notification(
        self,
        url: str,
        task_id: str,
        state: str,
        message: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        use_jwt_auth: bool = True,
    ) -> bool:
        """
        Send a push notification to the specified URL.
        Implements exponential backoff retry.

        Args:
            url: URL to send the notification
            task_id: Task ID
            state: Task state
            message: Optional message
            headers: Optional HTTP headers
            max_retries: Maximum number of retries
            retry_delay: Initial delay between retries
            use_jwt_auth: Whether to use JWT authentication

        Returns:
            True if the notification was sent successfully, False otherwise
        """
        payload = {
            "taskId": task_id,
            "state": state,
            "timestamp": datetime.now().isoformat(),
            "message": message,
        }

        # First URL verification
        if use_jwt_auth:
            is_url_valid = await push_notification_auth.verify_push_notification_url(
                url
            )
            if not is_url_valid:
                logger.warning(f"Invalid push notification URL: {url}")
                return False

        for attempt in range(max_retries):
            try:
                if use_jwt_auth:
                    # Use JWT authentication
                    success = await push_notification_auth.send_push_notification(
                        url, payload
                    )
                    if success:
                        return True
                else:
                    # Legacy method without JWT authentication
                    async with self.session.post(
                        url, json=payload, headers=headers or {}, timeout=10
                    ) as response:
                        if response.status in (200, 201, 202, 204):
                            logger.info(f"Push notification sent to URL: {url}")
                            return True
                        else:
                            logger.warning(
                                f"Failed to send push notification with status {response.status}. "
                                f"Attempt {attempt + 1}/{max_retries}"
                            )
            except Exception as e:
                logger.error(
                    f"Error sending push notification: {str(e)}. "
                    f"Attempt {attempt + 1}/{max_retries}"
                )

            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay * (2**attempt))

        return False

    async def close(self):
        """Close the HTTP session."""
        await self.session.close()


# Global instance of the push notification service
push_notification_service = PushNotificationService()
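Callers go through the module-level instance; with the defaults, failed attempts are retried with waits of 1 s and then 2 s between the three tries. A hedged usage sketch with placeholder values:

import asyncio


async def notify_example():
    ok = await push_notification_service.send_notification(  # instance defined above
        url="https://client.example.com/webhook",  # placeholder URL
        task_id="task-123",
        state="completed",
        message={"role": "agent", "parts": [{"type": "text", "text": "Done"}]},
        use_jwt_auth=True,  # signs the payload via push_notification_auth
    )
    if not ok:
        # All three attempts failed (with 1 s and 2 s backoff waits in between)
        ...


asyncio.run(notify_example())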
@ -1,555 +0,0 @@
"""
Redis cache service for the A2A protocol.

This service provides an interface for storing and retrieving data related to tasks,
push notification configurations, and other A2A-related data.
"""

import json
import logging
from typing import Any, Dict, List, Optional
import asyncio
import redis.asyncio as aioredis
from src.config.redis import get_redis_config
import threading
import time

logger = logging.getLogger(__name__)


class _InMemoryCacheFallback:
    """
    Fallback in-memory cache implementation for when Redis is not available.

    This should only be used for development or testing environments.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Singleton implementation."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize cache storage."""
        if not getattr(self, "_initialized", False):
            with self._lock:
                if not getattr(self, "_initialized", False):
                    self._data = {}
                    self._ttls = {}
                    self._hash_data = {}
                    self._list_data = {}
                    self._data_lock = threading.Lock()
                    self._initialized = True
                    logger.warning(
                        "Initializing in-memory cache fallback (not for production)"
                    )

    async def set(self, key, value, ex=None):
        """Set a key with optional expiration."""
        with self._data_lock:
            self._data[key] = value
            if ex is not None:
                self._ttls[key] = time.time() + ex
            elif key in self._ttls:
                del self._ttls[key]
            return True

    async def setex(self, key, ex, value):
        """Set a key with expiration."""
        return await self.set(key, value, ex)

    async def get(self, key):
        """Get a key value."""
        with self._data_lock:
            # Check if expired
            if key in self._ttls and time.time() > self._ttls[key]:
                del self._data[key]
                del self._ttls[key]
                return None
            return self._data.get(key)

    async def delete(self, key):
        """Delete a key."""
        with self._data_lock:
            if key in self._data:
                del self._data[key]
                if key in self._ttls:
                    del self._ttls[key]
                return 1
            return 0

    async def exists(self, key):
        """Check if a key exists."""
        with self._data_lock:
            if key in self._ttls and time.time() > self._ttls[key]:
                del self._data[key]
                del self._ttls[key]
                return 0
            return 1 if key in self._data else 0

    async def hset(self, key, field, value):
        """Set a hash field."""
        with self._data_lock:
            if key not in self._hash_data:
                self._hash_data[key] = {}
            self._hash_data[key][field] = value
            return 1

    async def hget(self, key, field):
        """Get a hash field."""
        with self._data_lock:
            if key not in self._hash_data:
                return None
            return self._hash_data[key].get(field)

    async def hdel(self, key, field):
        """Delete a hash field."""
        with self._data_lock:
            if key in self._hash_data and field in self._hash_data[key]:
                del self._hash_data[key][field]
                return 1
            return 0

    async def hgetall(self, key):
        """Get all hash fields."""
        with self._data_lock:
            if key not in self._hash_data:
                return {}
            return dict(self._hash_data[key])

    async def rpush(self, key, value):
        """Append to a list."""
        with self._data_lock:
            if key not in self._list_data:
                self._list_data[key] = []
            self._list_data[key].append(value)
            return len(self._list_data[key])

    async def lrange(self, key, start, end):
        """Get a range from a list (Redis LRANGE treats both indices as inclusive)."""
        with self._data_lock:
            if key not in self._list_data:
                return []

            values = self._list_data[key]
            # Normalize negative indices, then convert the inclusive end
            # index to Python's exclusive slice bound
            if end < 0:
                end = len(values) + end

            return values[start : end + 1]

    async def expire(self, key, seconds):
        """Set expiration on a key."""
        with self._data_lock:
            if key in self._data:
                self._ttls[key] = time.time() + seconds
                return 1
            return 0

    async def flushdb(self):
        """Clear all data."""
        with self._data_lock:
            self._data.clear()
            self._ttls.clear()
            self._hash_data.clear()
            self._list_data.clear()
            return True

    async def keys(self, pattern="*"):
        """Get keys matching a pattern."""
        with self._data_lock:
            # Clean expired keys
            now = time.time()
            expired_keys = [k for k, exp in self._ttls.items() if now > exp]
            for k in expired_keys:
                if k in self._data:
                    del self._data[k]
                del self._ttls[k]

            # Simple pattern matching
            result = []
            if pattern == "*":
                result = list(self._data.keys())
            elif pattern.endswith("*"):
                prefix = pattern[:-1]
                result = [k for k in self._data.keys() if k.startswith(prefix)]
            elif pattern.startswith("*"):
                suffix = pattern[1:]
                result = [k for k in self._data.keys() if k.endswith(suffix)]
            else:
                if pattern in self._data:
                    result = [pattern]

            return result

    async def ping(self):
        """Test the connection."""
        return True


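Note that Redis LRANGE treats both indices as inclusive, which is why the fallback's lrange converts the end index to end + 1 before slicing (the original slice dropped the last element for non-negative end values). A quick check of the expected semantics against the fallback:

import asyncio


async def demo():
    cache = _InMemoryCacheFallback()  # singleton class defined above
    for item in ("a", "b", "c"):
        await cache.rpush("demo:list", item)

    # LRANGE-style inclusive indices, matching real Redis
    assert await cache.lrange("demo:list", 0, 0) == ["a"]
    assert await cache.lrange("demo:list", 0, 1) == ["a", "b"]
    assert await cache.lrange("demo:list", 0, -1) == ["a", "b", "c"]


asyncio.run(demo())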
class RedisCacheService:
    """
    Cache service using Redis for storing A2A-related data.

    This implementation uses a real Redis connection for distributed caching
    and data persistence across multiple instances.

    If Redis is not available, falls back to an in-memory implementation.
    """

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize the Redis cache service.

        Args:
            redis_url: Redis server URL (optional, defaults to config value)
        """
        if redis_url:
            self._redis_url = redis_url
        else:
            # Build the URL from the configuration components
            config = get_redis_config()
            protocol = "rediss" if config.get("ssl", False) else "redis"
            auth = f":{config['password']}@" if config.get("password") else ""
            self._redis_url = (
                f"{protocol}://{auth}{config['host']}:{config['port']}/{config['db']}"
            )

        self._redis = None
        self._in_memory_mode = False
        self._connecting = False
        self._connection_lock = asyncio.Lock()
        logger.info(f"Initializing RedisCacheService with URL: {self._redis_url}")

    async def _get_redis(self):
        """
        Get a Redis connection, creating it if necessary.
        Falls back to the in-memory implementation if Redis is not available.

        Returns:
            Redis connection or in-memory fallback
        """
        if self._redis is not None:
            return self._redis

        async with self._connection_lock:
            if self._redis is None and not self._connecting:
                try:
                    self._connecting = True
                    logger.info(f"Connecting to Redis at {self._redis_url}")
                    self._redis = aioredis.from_url(
                        self._redis_url, encoding="utf-8", decode_responses=True
                    )
                    # Connection test
                    await self._redis.ping()
                    logger.info("Redis connection successful")
                except Exception as e:
                    logger.error(f"Error connecting to Redis: {str(e)}")
                    logger.warning(
                        "Falling back to in-memory cache (not suitable for production)"
                    )
                    self._redis = _InMemoryCacheFallback()
                    self._in_memory_mode = True
                finally:
                    self._connecting = False

        return self._redis

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """
        Store a value in the cache.

        Args:
            key: Key for the value
            value: Value to store
            ttl: Time to live in seconds (optional)
        """
        try:
            redis = await self._get_redis()

            # Convert dict/list to a JSON string
            if isinstance(value, (dict, list)):
                value = json.dumps(value)

            if ttl:
                await redis.setex(key, ttl, value)
            else:
                await redis.set(key, value)

            logger.debug(f"Set cache key: {key}")
        except Exception as e:
            logger.error(f"Error setting Redis key {key}: {str(e)}")

    async def get(self, key: str) -> Optional[Any]:
        """
        Retrieve a value from the cache.

        Args:
            key: Key for the value to retrieve

        Returns:
            The stored value or None if not found
        """
        try:
            redis = await self._get_redis()
            value = await redis.get(key)

            if value is None:
                return None

            try:
                # Try to parse as JSON
                return json.loads(value)
            except json.JSONDecodeError:
                # Return as-is if not JSON
                return value

        except Exception as e:
            logger.error(f"Error getting Redis key {key}: {str(e)}")
            return None

    async def delete(self, key: str) -> bool:
        """
        Remove a value from the cache.

        Args:
            key: Key for the value to remove

        Returns:
            True if the value was removed, False if it didn't exist
        """
        try:
            redis = await self._get_redis()
            result = await redis.delete(key)
            return result > 0
        except Exception as e:
            logger.error(f"Error deleting Redis key {key}: {str(e)}")
            return False

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in the cache.

        Args:
            key: Key to check

        Returns:
            True if the key exists, False otherwise
        """
        try:
            redis = await self._get_redis()
            return await redis.exists(key) > 0
        except Exception as e:
            logger.error(f"Error checking Redis key {key}: {str(e)}")
            return False

    async def set_hash(self, key: str, field: str, value: Any) -> None:
        """
        Store a value in a hash.

        Args:
            key: Hash key
            field: Hash field
            value: Value to store
        """
        try:
            redis = await self._get_redis()

            # Convert dict/list to a JSON string
            if isinstance(value, (dict, list)):
                value = json.dumps(value)

            await redis.hset(key, field, value)
            logger.debug(f"Set hash field: {key}:{field}")
        except Exception as e:
            logger.error(f"Error setting Redis hash {key}:{field}: {str(e)}")

    async def get_hash(self, key: str, field: str) -> Optional[Any]:
        """
        Retrieve a value from a hash.

        Args:
            key: Hash key
            field: Hash field

        Returns:
            The stored value or None if not found
        """
        try:
            redis = await self._get_redis()
            value = await redis.hget(key, field)

            if value is None:
                return None

            try:
                # Try to parse as JSON
                return json.loads(value)
            except json.JSONDecodeError:
                # Return as-is if not JSON
                return value

        except Exception as e:
            logger.error(f"Error getting Redis hash {key}:{field}: {str(e)}")
            return None

    async def delete_hash(self, key: str, field: str) -> bool:
        """
        Remove a value from a hash.

        Args:
            key: Hash key
            field: Hash field

        Returns:
            True if the value was removed, False if it didn't exist
        """
        try:
            redis = await self._get_redis()
            result = await redis.hdel(key, field)
            return result > 0
        except Exception as e:
            logger.error(f"Error deleting Redis hash {key}:{field}: {str(e)}")
            return False

    async def get_all_hash(self, key: str) -> Dict[str, Any]:
        """
        Retrieve all values from a hash.

        Args:
            key: Hash key

        Returns:
            Dictionary with all hash values
        """
        try:
            redis = await self._get_redis()
            result_dict = await redis.hgetall(key)

            if not result_dict:
                return {}

            # Try to parse each value as JSON
            parsed_dict = {}
            for field, value in result_dict.items():
                try:
                    parsed_dict[field] = json.loads(value)
                except json.JSONDecodeError:
                    parsed_dict[field] = value

            return parsed_dict

        except Exception as e:
            logger.error(f"Error getting all Redis hash fields for {key}: {str(e)}")
            return {}

    async def push_list(self, key: str, value: Any) -> int:
        """
        Add a value to the end of a list.

        Args:
            key: List key
            value: Value to add

        Returns:
            Size of the list after the addition
        """
        try:
            redis = await self._get_redis()

            # Convert dict/list to a JSON string
            if isinstance(value, (dict, list)):
                value = json.dumps(value)

            return await redis.rpush(key, value)
        except Exception as e:
            logger.error(f"Error pushing to Redis list {key}: {str(e)}")
            return 0

    async def get_list(self, key: str, start: int = 0, end: int = -1) -> List[Any]:
        """
        Retrieve values from a list.

        Args:
            key: List key
            start: Initial index (inclusive)
            end: Final index (inclusive), -1 for all

        Returns:
            List with the retrieved values
        """
        try:
            redis = await self._get_redis()
            values = await redis.lrange(key, start, end)

            if not values:
                return []

            # Try to parse each value as JSON
            result = []
            for value in values:
                try:
                    result.append(json.loads(value))
                except json.JSONDecodeError:
                    result.append(value)

            return result

        except Exception as e:
            logger.error(f"Error getting Redis list {key}: {str(e)}")
            return []

    async def expire(self, key: str, ttl: int) -> bool:
        """
        Set a time-to-live for a key.

        Args:
            key: Key
            ttl: Time-to-live in seconds

        Returns:
            True if the key exists and the TTL was set, False otherwise
        """
        try:
            redis = await self._get_redis()
            return await redis.expire(key, ttl)
        except Exception as e:
            logger.error(f"Error setting expire for Redis key {key}: {str(e)}")
            return False

    async def clear(self) -> None:
        """
        Clear the entire cache.

        Warning: this is a destructive operation and will remove all data.
        Only use in development/test environments.
        """
        try:
            redis = await self._get_redis()
            await redis.flushdb()
            logger.warning("Redis database flushed - all data cleared")
        except Exception as e:
            logger.error(f"Error clearing Redis database: {str(e)}")

    async def keys(self, pattern: str = "*") -> List[str]:
        """
        Retrieve keys that match a pattern.

        Args:
            pattern: Glob pattern to filter keys

        Returns:
            List of keys that match the pattern
        """
        try:
            redis = await self._get_redis()
            return await redis.keys(pattern)
        except Exception as e:
            logger.error(f"Error getting Redis keys with pattern {pattern}: {str(e)}")
            return []
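A short usage sketch of the service above; the key names and values are placeholders, and the call falls back to the in-memory store automatically if Redis is unreachable:

import asyncio


async def cache_demo():
    cache = RedisCacheService()  # class defined above

    await cache.set("task:demo", {"state": "submitted"}, ttl=60)
    task = await cache.get("task:demo")  # dicts round-trip through JSON
    assert task == {"state": "submitted"}

    await cache.set_hash("task:demo:meta", "owner", "agent-1")
    assert await cache.get_hash("task:demo:meta", "owner") == "agent-1"

    await cache.delete("task:demo")


asyncio.run(cache_demo())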
@ -1,141 +0,0 @@
import uuid
import json
from typing import AsyncGenerator, Dict, Any
from fastapi import HTTPException
from datetime import datetime
from src.schemas.streaming import (
    JSONRPCRequest,
    TaskStatusUpdateEvent,
)
from src.services.agent_runner import run_agent
from src.services.service_providers import (
    session_service,
    artifacts_service,
    memory_service,
)
from sqlalchemy.orm import Session


class StreamingService:
    def __init__(self):
        self.active_connections: Dict[str, Any] = {}

    async def send_task_streaming(
        self,
        agent_id: str,
        api_key: str,
        message: str,
        contact_id: str = None,
        session_id: str = None,
        db: Session = None,
    ) -> AsyncGenerator[str, None]:
        """
        Starts the SSE event streaming for a task.

        Args:
            agent_id: Agent ID
            api_key: API key for authentication
            message: Initial message
            contact_id: Contact ID (optional)
            session_id: Session ID (optional)
            db: Database session

        Yields:
            Formatted SSE events
        """
        # Basic API key validation
        if not api_key:
            raise HTTPException(status_code=401, detail="API key is required")

        # Generate unique IDs
        task_id = contact_id or str(uuid.uuid4())
        request_id = str(uuid.uuid4())

        # Build the JSON-RPC payload
        payload = JSONRPCRequest(
            id=request_id,
            params={
                "id": task_id,
                "sessionId": session_id,
                "message": {
                    "role": "user",
                    "parts": [{"type": "text", "text": message}],
                },
            },
        )

        # Register the connection
        self.active_connections[task_id] = {
            "agent_id": agent_id,
            "api_key": api_key,
            "session_id": session_id,
        }

        try:
            # Send the start event
            yield self._format_sse_event(
                "status",
                TaskStatusUpdateEvent(
                    state="working",
                    timestamp=datetime.now().isoformat(),
                    message=payload.params["message"],
                ).model_dump_json(),
            )

            # Execute the agent
            result = await run_agent(
                str(agent_id),
                contact_id or task_id,
                message,
                session_service,
                artifacts_service,
                memory_service,
                db,
                session_id,
            )

            # Send the agent's response as a separate event
            yield self._format_sse_event(
                "message",
                json.dumps(
                    {
                        "role": "agent",
                        "content": result,
                        "timestamp": datetime.now().isoformat(),
                    }
                ),
            )

            # Completion event
            yield self._format_sse_event(
                "status",
                TaskStatusUpdateEvent(
                    state="completed",
                    timestamp=datetime.now().isoformat(),
                ).model_dump_json(),
            )

        except Exception as e:
            # Error event
            yield self._format_sse_event(
                "status",
                TaskStatusUpdateEvent(
                    state="failed",
                    timestamp=datetime.now().isoformat(),
                    error={"message": str(e)},
                ).model_dump_json(),
            )
            raise

        finally:
            # Clean up the connection
            self.active_connections.pop(task_id, None)

    def _format_sse_event(self, event_type: str, data: str) -> str:
        """Format an SSE event."""
        return f"event: {event_type}\ndata: {data}\n\n"

    async def close_connection(self, task_id: str):
        """Close a streaming connection."""
        if task_id in self.active_connections:
            self.active_connections.pop(task_id)
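Since send_task_streaming already yields fully formatted "event:"/"data:" frames, a route would wrap it in a plain StreamingResponse with the text/event-stream media type rather than re-wrapping it in EventSourceResponse. A minimal sketch; the path and parameter names are assumptions:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
streaming_service = StreamingService()  # class defined above


@app.post("/agents/{agent_id}/stream")
async def stream_task(agent_id: str, message: str, x_api_key: str):
    # send_task_streaming emits preformatted SSE frames, so a raw
    # StreamingResponse is used here instead of EventSourceResponse
    generator = streaming_service.send_task_streaming(
        agent_id=agent_id, api_key=x_api_key, message=message
    )
    return StreamingResponse(generator, media_type="text/event-stream")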
@ -1,110 +1,28 @@
"""
A2A protocol utilities.

This module contains utility functions for the A2A protocol implementation.
"""

import logging
from typing import List, Optional, Any, Dict
from src.schemas.a2a_types import (
    ContentTypeNotSupportedError,
    UnsupportedOperationError,
    JSONRPCResponse,
)

logger = logging.getLogger(__name__)


def are_modalities_compatible(
    server_output_modes: Optional[List[str]], client_output_modes: Optional[List[str]]
) -> bool:
    """
    Check if server and client modalities are compatible.

    Modalities are compatible if they are both non-empty
    and there is at least one common element.

    Args:
        server_output_modes: List of output modes supported by the server
        client_output_modes: List of output modes requested by the client

    Returns:
        True if compatible, False otherwise
    """
    # If the client doesn't specify modes, assume all are accepted
    if client_output_modes is None or len(client_output_modes) == 0:
        return True

    # If the server doesn't specify modes, assume all are supported
    if server_output_modes is None or len(server_output_modes) == 0:
        return True

    # Check if there's at least one common mode
    return any(mode in server_output_modes for mode in client_output_modes)

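The resulting behavior is easiest to see with a few concrete calls to the function above:

# Unspecified modes on either side are treated as "accept anything"
assert are_modalities_compatible(["text", "file"], None) is True
assert are_modalities_compatible(None, ["text"]) is True

# Otherwise at least one requested mode must be supported
assert are_modalities_compatible(["text", "file"], ["text"]) is True
assert are_modalities_compatible(["text"], ["audio"]) is False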
def new_incompatible_types_error(request_id: str) -> JSONRPCResponse:
    """
    Create a JSON-RPC response for an incompatible content types error.

    Args:
        request_id: The ID of the request that caused the error

    Returns:
        JSON-RPC response with ContentTypeNotSupportedError
    """
    return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())


def new_not_implemented_error(request_id: str) -> JSONRPCResponse:
    """
    Create a JSON-RPC response for an unimplemented operation error.

    Args:
        request_id: The ID of the request that caused the error

    Returns:
        JSON-RPC response with UnsupportedOperationError
    """
    return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())

def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str:
    """
    Create a unique task ID for an agent and session.

    Args:
        agent_id: The ID of the agent
        session_id: The ID of the session
        timestamp: Optional timestamp to include in the ID

    Returns:
        A unique task ID
    """
    import uuid
    import time

    timestamp = timestamp or str(int(time.time()))
    unique = uuid.uuid4().hex[:8]

    return f"{agent_id}_{session_id}_{timestamp}_{unique}"


def format_error_response(error: Exception, request_id: str = None) -> Dict[str, Any]:
    """
    Format an exception as a JSON-RPC error response.

    Args:
        error: The exception to format
        request_id: The ID of the request that caused the error

    Returns:
        JSON-RPC error response as a dictionary
    """
    from src.schemas.a2a import InternalError, JSONRPCResponse

    error_response = JSONRPCResponse(
        id=request_id, error=InternalError(message=str(error))
    )

    return error_response.model_dump(exclude_none=True)