diff --git a/.env b/.env index e569655a..d0882bd4 100644 --- a/.env +++ b/.env @@ -18,6 +18,9 @@ REDIS_HOST="localhost" REDIS_PORT=6379 REDIS_DB=8 REDIS_PASSWORD="" +REDIS_SSL=false +REDIS_KEY_PREFIX="a2a:" +REDIS_TTL=3600 # Tools cache TTL in seconds (1 hour) TOOLS_CACHE_TTL=3600 diff --git a/.env.example b/.env.example index 7f7f26c9..5aed4c9e 100644 --- a/.env.example +++ b/.env.example @@ -18,6 +18,9 @@ REDIS_HOST="localhost" REDIS_PORT=6379 REDIS_DB=0 REDIS_PASSWORD="your-redis-password" +REDIS_SSL=false +REDIS_KEY_PREFIX="a2a:" +REDIS_TTL=3600 # Tools cache TTL in seconds (1 hour) TOOLS_CACHE_TTL=3600 @@ -44,3 +47,9 @@ ADMIN_INITIAL_PASSWORD="strongpassword123" DEMO_EMAIL="demo@example.com" DEMO_PASSWORD="demo123" DEMO_CLIENT_NAME="Demo Client" + +# A2A settings +A2A_TASK_TTL=3600 +A2A_HISTORY_TTL=86400 +A2A_PUSH_NOTIFICATION_TTL=3600 +A2A_SSE_CLIENT_TTL=300 diff --git a/a2a_checklist.md b/a2a_checklist.md deleted file mode 100644 index 3277bad4..00000000 --- a/a2a_checklist.md +++ /dev/null @@ -1,252 +0,0 @@ -# Checklist de Implementação do Protocolo A2A com Redis - -## 1. Configuração Inicial - -- [ ] **Configurar dependências no arquivo pyproject.toml** - - - Adicionar Redis e dependências relacionadas: - ``` - redis = "^5.3.0" - sse-starlette = "^2.3.3" - jwcrypto = "^1.5.6" - pyjwt = {extras = ["crypto"], version = "^2.10.1"} - ``` - -- [ ] **Configurar variáveis de ambiente para Redis** - - - Adicionar em `.env.example` e `.env`: - ``` - REDIS_HOST=localhost - REDIS_PORT=6379 - REDIS_PASSWORD= - REDIS_DB=0 - REDIS_SSL=false - REDIS_KEY_PREFIX=a2a: - REDIS_TTL=3600 - ``` - -- [ ] **Configurar Redis no docker-compose.yml** - - Adicionar serviço Redis com portas e volumes apropriados - - Configurar segurança básica (senha, se necessário) - -## 2. Implementação de Modelos e Schemas - -- [ ] **Criar schemas A2A em `src/schemas/a2a.py`** - - - Implementar tipos conforme `docs/A2A/samples/python/common/types.py`: - - Enums (TaskState, etc.) - - Classes de mensagens (TextPart, FilePart, etc.) - - Classes de tarefas (Task, TaskStatus, etc.) - - Estruturas JSON-RPC - - Tipos de erros - -- [ ] **Implementar validadores de modelo** - - Validadores para conteúdos de arquivo - - Validadores para formatos de mensagem - - Conversores de formato para compatibilidade com o protocolo - -## 3. Implementação do Cache Redis - -- [ ] **Criar configuração Redis em `src/config/redis.py`** - - - Implementar função de conexão com pool - - Configurar opções de segurança (SSL, autenticação) - - Configurar TTL padrão para diferentes tipos de dados - -- [ ] **Criar serviço de cache Redis em `src/services/redis_cache_service.py`** - - - Implementar métodos do exemplo com suporte a Redis: - ```python - class RedisCache: - async def get_task(self, task_id: str) -> dict - async def save_task(self, task_id: str, task_data: dict, ttl: int = 3600) -> None - async def update_task_status(self, task_id: str, status: dict) -> bool - async def append_to_history(self, task_id: str, message: dict) -> bool - async def save_push_notification_config(self, task_id: str, config: dict) -> None - async def get_push_notification_config(self, task_id: str) -> dict - async def save_sse_client(self, task_id: str, client_id: str) -> None - async def get_sse_clients(self, task_id: str) -> list - async def remove_sse_client(self, task_id: str, client_id: str) -> None - ``` - -- [ ] **Implementar funcionalidades para gerenciamento de conexões** - - Reconexão automática - - Fallback para cache em memória em caso de falha - - Métricas de desempenho - -## 4. Serviços A2A - -- [ ] **Implementar utilitários A2A em `src/utils/a2a_utils.py`** - - - Implementar funções conforme `docs/A2A/samples/python/common/server/utils.py`: - ```python - def are_modalities_compatible(server_output_modes, client_output_modes) - def new_incompatible_types_error(request_id) - def new_not_implemented_error(request_id) - ``` - -- [ ] **Implementar `A2ATaskManager` em `src/services/a2a_task_manager_service.py`** - - - Seguir a interface do `TaskManager` do exemplo - - Implementar todos os métodos abstratos: - ```python - async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse - async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse - async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse - async def on_send_task_subscribe(self, request) -> Union[AsyncIterable, JSONRPCResponse] - async def on_set_task_push_notification(self, request) -> SetTaskPushNotificationResponse - async def on_get_task_push_notification(self, request) -> GetTaskPushNotificationResponse - async def on_resubscribe_to_task(self, request) -> Union[AsyncIterable, JSONRPCResponse] - ``` - - Utilizar Redis para persistência de dados de tarefa - -- [ ] **Implementar `A2AServer` em `src/services/a2a_server_service.py`** - - - Processar requisições JSON-RPC conforme `docs/A2A/samples/python/common/server/server.py`: - ```python - async def _process_request(self, request: Request) - def _handle_exception(self, e: Exception) -> JSONResponse - def _create_response(self, result: Any) -> Union[JSONResponse, EventSourceResponse] - ``` - -- [ ] **Integrar com agent_runner.py existente** - - - Adaptar `run_agent` para uso no contexto de tarefas A2A - - Implementar mapeamento entre formatos de mensagem - -- [ ] **Integrar com streaming_service.py existente** - - Adaptar para formato de eventos compatível com A2A - - Implementar suporte a streaming de múltiplos tipos de eventos - -## 5. Autenticação e Push Notifications - -- [ ] **Implementar `PushNotificationAuth` em `src/services/push_notification_auth_service.py`** - - - Seguir o exemplo em `docs/A2A/samples/python/common/utils/push_notification_auth.py` - - Implementar: - ```python - def generate_jwk(self) - def handle_jwks_endpoint(self, request: Request) - async def send_authenticated_push_notification(self, url: str, data: dict) - ``` - -- [ ] **Implementar verificação de URL de notificação** - - - Seguir método `verify_push_notification_url` do exemplo - - Implementar validação de token para verificação - -- [ ] **Implementar armazenamento seguro de chaves** - - Armazenar chaves privadas de forma segura - - Rotação periódica de chaves - - Gerenciamento do ciclo de vida das chaves - -## 6. Rotas A2A - -- [ ] **Implementar rotas em `src/api/a2a_routes.py`** - - - Criar endpoint principal para processamento de requisições JSON-RPC: - - ```python - @router.post("/{agent_id}") - async def process_a2a_request(agent_id: str, request: Request, x_api_key: str = Header(None)) - ``` - - - Implementar endpoint do Agent Card reutilizando lógica existente: - - ```python - @router.get("/{agent_id}/.well-known/agent.json") - async def get_agent_card(agent_id: str, db: Session = Depends(get_db)) - ``` - - - Implementar endpoint JWKS para autenticação de push notifications: - ```python - @router.get("/{agent_id}/.well-known/jwks.json") - async def get_jwks(agent_id: str, db: Session = Depends(get_db)) - ``` - -- [ ] **Registrar rotas A2A no aplicativo principal** - - Adicionar importação e inclusão em `src/main.py`: - ```python - app.include_router(a2a_routes.router, prefix="/api/v1") - ``` - -## 7. Testes - -- [ ] **Criar testes unitários para schemas A2A** - - - Testar validadores - - Testar conversões de formato - - Testar compatibilidade de modalidades - -- [ ] **Criar testes unitários para cache Redis** - - - Testar todas as operações CRUD - - Testar expiração de dados - - Testar comportamento com falhas de conexão - -- [ ] **Criar testes unitários para gerenciador de tarefas** - - - Testar ciclo de vida da tarefa - - Testar cancelamento de tarefas - - Testar notificações push - -- [ ] **Criar testes de integração para endpoints A2A** - - Testar requisições completas - - Testar streaming - - Testar cenários de erro - -## 8. Segurança - -- [ ] **Implementar validação de API key** - - - Verificar API key para todas as requisições - - Implementar rate limiting por agente/cliente - -- [ ] **Configurar segurança no Redis** - - - Ativar autenticação e SSL em produção - - Definir políticas de retenção de dados - - Implementar backup e recuperação - -- [ ] **Configurar segurança para push notifications** - - Implementar assinatura JWT - - Validar URLs de callback - - Implementar retry com backoff para falhas - -## 9. Monitoramento e Métricas - -- [ ] **Implementar métricas de Redis** - - - Taxa de acertos/erros do cache - - Tempo de resposta - - Uso de memória - -- [ ] **Implementar métricas de tarefas A2A** - - - Número de tarefas por estado - - Tempo médio de processamento - - Taxa de erros - -- [ ] **Configurar logging apropriado** - - Registrar eventos importantes - - Mascarar dados sensíveis - - Implementar níveis de log configuráveis - -## 10. Documentação - -- [ ] **Documentar API A2A** - - - Descrever endpoints e formatos - - Fornecer exemplos de uso - - Documentar erros e soluções - -- [ ] **Documentar integração com Redis** - - - Descrever configuração - - Explicar estratégia de cache - - Documentar TTLs e políticas de expiração - -- [ ] **Criar exemplos de clients** - - Implementar exemplos de uso em Python - - Documentar fluxos comuns - - Fornecer snippets para linguagens populares diff --git a/a2a_client_test.py b/a2a_client_test.py new file mode 100644 index 00000000..96442d25 --- /dev/null +++ b/a2a_client_test.py @@ -0,0 +1,187 @@ +import logging +import httpx +from httpx_sse import connect_sse +from typing import Any, AsyncIterable, Optional +from docs.A2A.samples.python.common.types import ( + AgentCard, + GetTaskRequest, + SendTaskRequest, + SendTaskResponse, + JSONRPCRequest, + GetTaskResponse, + CancelTaskResponse, + CancelTaskRequest, + SetTaskPushNotificationRequest, + SetTaskPushNotificationResponse, + GetTaskPushNotificationRequest, + GetTaskPushNotificationResponse, + A2AClientHTTPError, + A2AClientJSONError, + SendTaskStreamingRequest, + SendTaskStreamingResponse, +) +import json +import asyncio +import uuid + + +# Configurar logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger("a2a_client_runner") + + +class A2ACardResolver: + def __init__(self, base_url, agent_card_path="/.well-known/agent.json"): + self.base_url = base_url.rstrip("/") + self.agent_card_path = agent_card_path.lstrip("/") + + def get_agent_card(self) -> AgentCard: + with httpx.Client() as client: + response = client.get(self.base_url + "/" + self.agent_card_path) + response.raise_for_status() + try: + return AgentCard(**response.json()) + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + + +class A2AClient: + def __init__( + self, + agent_card: AgentCard = None, + url: str = None, + api_key: Optional[str] = None, + ): + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError("Must provide either agent_card or url") + self.api_key = api_key + self.headers = {"x-api-key": api_key} if api_key else {} + + async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse: + request = SendTaskRequest(params=payload) + return SendTaskResponse(**await self._send_request(request)) + + async def send_task_streaming( + self, payload: dict[str, Any] + ) -> AsyncIterable[SendTaskStreamingResponse]: + request = SendTaskStreamingRequest(params=payload) + with httpx.Client(timeout=None) as client: + with connect_sse( + client, + "POST", + self.url, + json=request.model_dump(), + headers=self.headers, + ) as event_source: + try: + for sse in event_source.iter_sse(): + yield SendTaskStreamingResponse(**json.loads(sse.data)) + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError(400, str(e)) from e + + async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]: + async with httpx.AsyncClient() as client: + try: + # Image generation could take time, adding timeout + response = await client.post( + self.url, + json=request.model_dump(), + headers=self.headers, + timeout=30, + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + + async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse: + request = GetTaskRequest(params=payload) + return GetTaskResponse(**await self._send_request(request)) + + async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse: + request = CancelTaskRequest(params=payload) + return CancelTaskResponse(**await self._send_request(request)) + + async def set_task_callback( + self, payload: dict[str, Any] + ) -> SetTaskPushNotificationResponse: + request = SetTaskPushNotificationRequest(params=payload) + return SetTaskPushNotificationResponse(**await self._send_request(request)) + + async def get_task_callback( + self, payload: dict[str, Any] + ) -> GetTaskPushNotificationResponse: + request = GetTaskPushNotificationRequest(params=payload) + return GetTaskPushNotificationResponse(**await self._send_request(request)) + + +async def main(): + # Configurações + BASE_URL = "http://localhost:8000/api/v1/a2a/18a2889e-8573-4e70-833c-7d9e00a8fd80" + API_KEY = "83c2c19f-dc2e-4abe-9a41-ef7d2eb079d6" + + try: + # Obter o card do agente + logger.info("Obtendo card do agente...") + card_resolver = A2ACardResolver(BASE_URL) + try: + card = card_resolver.get_agent_card() + logger.info(f"Card do agente: {card}") + except Exception as e: + logger.error(f"Erro ao obter card do agente: {e}") + return + + # Criar cliente A2A com API key + client = A2AClient(card, api_key=API_KEY) + + # Exemplo 1: Enviar tarefa síncrona + logger.info("\n=== TESTE DE TAREFA SÍNCRONA ===") + task_id = str(uuid.uuid4()) + session_id = "test-session-1" + + # Preparar payload da tarefa + payload = { + "id": task_id, + "sessionId": session_id, + "message": { + "role": "user", + "parts": [ + { + "type": "text", + "text": "Quais são os três maiores países do mundo em área territorial?", + } + ], + }, + } + + logger.info(f"Enviando tarefa com ID: {task_id}") + async for streaming_response in client.send_task_streaming(payload): + if hasattr(streaming_response.result, "artifact"): + # Processar conteúdo parcial + print(streaming_response.result.artifact.parts[0].text) + elif ( + hasattr(streaming_response.result, "status") + and streaming_response.result.status.state == "completed" + ): + # Tarefa concluída + print( + "Resposta final:", + streaming_response.result.status.message.parts[0].text, + ) + + except Exception as e: + logger.error(f"Erro durante execução dos testes: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/a2a_feature.md b/a2a_feature.md deleted file mode 100644 index c9d1b2f0..00000000 --- a/a2a_feature.md +++ /dev/null @@ -1,491 +0,0 @@ -# Implementação do Servidor A2A (Agent-to-Agent) - -## Visão Geral - -Este documento descreve o plano de implementação para integrar o servidor A2A (Agent-to-Agent) no sistema existente. A implementação seguirá os exemplos fornecidos nos arquivos de referência, adaptando-os à estrutura atual do projeto. - -## Componentes a Serem Implementados - -### 1. Servidor A2A - -Implementação da classe `A2AServer` como serviço para gerenciar requisições JSON-RPC compatíveis com o protocolo A2A. - -### 2. Gerenciador de Tarefas - -Implementação do `TaskManager` como serviço para gerenciar o ciclo de vida das tarefas do agente. - -### 3. Adaptadores para Integração - -Criação de adaptadores para integrar o servidor A2A com os serviços existentes, como o streaming_service e push_notification_service. - -## Rotas e Endpoints A2A - -O protocolo A2A requer a implementação das seguintes rotas JSON-RPC: - -### 1. `POST /a2a/{agent_id}` - -Endpoint principal que processa todas as solicitações JSON-RPC para um agente específico. - -### 2. `GET /a2a/{agent_id}/.well-known/agent.json` - -Retorna o Agent Card contendo as informações do agente. - -### Métodos JSON-RPC implementados: - -1. **tasks/send** - Envia uma nova tarefa para o agente. - - - Parâmetros: `id`, `sessionId`, `message`, `acceptedOutputModes` (opcional), `pushNotification` (opcional) - - Retorna: Status da tarefa, artefatos gerados e histórico - -2. **tasks/sendSubscribe** - Envia uma tarefa e assina para receber atualizações via streaming (SSE). - - - Parâmetros: Mesmos do `tasks/send` - - Retorna: Stream de eventos SSE com atualizações de status e artefatos - -3. **tasks/get** - Obtém o status atual de uma tarefa. - - - Parâmetros: `id`, `historyLength` (opcional) - - Retorna: Status atual da tarefa, artefatos e histórico - -4. **tasks/cancel** - Tenta cancelar uma tarefa em execução. - - - Parâmetros: `id` - - Retorna: Status atualizado da tarefa - -5. **tasks/pushNotification/set** - Configura notificações push para uma tarefa. - - - Parâmetros: `id`, `pushNotificationConfig` (URL e autenticação) - - Retorna: Configuração de notificação atualizada - -6. **tasks/pushNotification/get** - Obtém a configuração de notificações push de uma tarefa. - - - Parâmetros: `id` - - Retorna: Configuração de notificação atual - -7. **tasks/resubscribe** - Reassina para receber eventos de uma tarefa existente. - - Parâmetros: `id` - - Retorna: Stream de eventos SSE - -## Streaming e Push Notifications - -### Streaming (SSE) - -O streaming será implementado usando Server-Sent Events (SSE) para enviar atualizações em tempo real aos clientes. - -#### Integração com streaming_service.py existente - -O serviço atual `StreamingService` já implementa funcionalidades de streaming SSE. Iremos expandir e adaptar: - -```python -# Exemplo de integração com o streaming_service.py existente -async def send_task_streaming(request: SendTaskStreamingRequest) -> AsyncIterable[SendTaskStreamingResponse]: - stream_service = StreamingService() - - async for event in stream_service.send_task_streaming( - agent_id=request.params.metadata.get("agent_id"), - api_key=api_key, - message=request.params.message.parts[0].text, - session_id=request.params.sessionId, - db=db - ): - # Converter formato de evento SSE para formato A2A - yield SendTaskStreamingResponse( - id=request.id, - result=convert_to_a2a_event_format(event) - ) -``` - -### Push Notifications - -O sistema de notificações push permitirá que o agente envie atualizações para URLs de callback configuradas pelos clientes. - -#### Integração com push_notification_service.py existente - -O serviço atual `PushNotificationService` já implementa o envio de notificações. Iremos adaptar: - -```python -# Exemplo de integração com o push_notification_service.py existente -async def send_push_notification(task_id, state, message=None): - notification_config = await get_push_notification_config(task_id) - if notification_config: - await push_notification_service.send_notification( - url=notification_config.url, - task_id=task_id, - state=state, - message=message, - headers=notification_config.headers - ) -``` - -#### Autenticação de Push Notifications - -Implementaremos autenticação segura para as notificações push baseada em JWT usando o `PushNotificationAuth`: - -```python -# Exemplo de como configurar autenticação nas notificações -push_auth = PushNotificationSenderAuth() -push_auth.generate_jwk() - -# Incluir rota para obter as chaves públicas -@router.get("/{agent_id}/.well-known/jwks.json") -async def get_jwks(agent_id: str): - return push_auth.handle_jwks_endpoint(request) - -# Integrar autenticação ao enviar notificações -async def send_authenticated_push_notification(url, data): - await push_auth.send_push_notification(url, data) -``` - -## Estratégia de Armazenamento de Dados - -### Uso de Redis para Dados Temporários - -Utilizaremos Redis para armazenamento e gerenciamento dos dados temporários das tarefas A2A, substituindo o cache em memória do exemplo original: - -```python -from src.services.redis_cache_service import RedisCacheService - -class RedisCache: - def __init__(self, redis_service: RedisCacheService): - self.redis = redis_service - - # Métodos para gerenciamento de tarefas - async def get_task(self, task_id: str) -> dict: - """Recupera uma tarefa pelo ID.""" - return self.redis.get(f"task:{task_id}") - - async def save_task(self, task_id: str, task_data: dict, ttl: int = 3600) -> None: - """Salva uma tarefa com TTL configurável.""" - self.redis.set(f"task:{task_id}", task_data, ttl=ttl) - - async def update_task_status(self, task_id: str, status: dict) -> bool: - """Atualiza o status de uma tarefa.""" - task_data = await self.get_task(task_id) - if not task_data: - return False - task_data["status"] = status - await self.save_task(task_id, task_data) - return True - - # Métodos para histórico de tarefas - async def append_to_history(self, task_id: str, message: dict) -> bool: - """Adiciona uma mensagem ao histórico da tarefa.""" - task_data = await self.get_task(task_id) - if not task_data: - return False - - if "history" not in task_data: - task_data["history"] = [] - - task_data["history"].append(message) - await self.save_task(task_id, task_data) - return True - - # Métodos para notificações push - async def save_push_notification_config(self, task_id: str, config: dict) -> None: - """Salva a configuração de notificação push para uma tarefa.""" - self.redis.set(f"push_notification:{task_id}", config, ttl=3600) - - async def get_push_notification_config(self, task_id: str) -> dict: - """Recupera a configuração de notificação push de uma tarefa.""" - return self.redis.get(f"push_notification:{task_id}") - - # Métodos para SSE (Server-Sent Events) - async def save_sse_client(self, task_id: str, client_id: str) -> None: - """Registra um cliente SSE para uma tarefa.""" - self.redis.set_hash(f"sse_clients:{task_id}", client_id, "active") - - async def get_sse_clients(self, task_id: str) -> list: - """Recupera todos os clientes SSE registrados para uma tarefa.""" - return self.redis.get_all_hash(f"sse_clients:{task_id}") - - async def remove_sse_client(self, task_id: str, client_id: str) -> None: - """Remove um cliente SSE do registro.""" - self.redis.delete_hash(f"sse_clients:{task_id}", client_id) -``` - -O serviço Redis será configurado com TTL (time-to-live) para garantir a limpeza automática de dados temporários: - -```python -# Configuração de TTL para diferentes tipos de dados -TASK_TTL = 3600 # 1 hora para tarefas -HISTORY_TTL = 86400 # 24 horas para histórico -PUSH_NOTIFICATION_TTL = 3600 # 1 hora para configurações de notificação -SSE_CLIENT_TTL = 300 # 5 minutos para clientes SSE -``` - -### Modelos Existentes e Redis - -O sistema continuará utilizando os modelos SQLAlchemy existentes para dados permanentes: - -- **Agent**: Dados do agente e configurações -- **Session**: Sessões persistentes -- **MCPServer**: Configurações de servidores de ferramentas - -Para dados temporários (tarefas A2A, histórico, streaming), utilizaremos Redis que oferece: - -1. **Performance**: Operações em memória com persistência opcional -2. **TTL**: Expiração automática de dados temporários -3. **Estruturas de dados**: Suporte a strings, hashes, listas para diferentes necessidades -4. **Pub/Sub**: Mecanismo para notificações em tempo real -5. **Escalabilidade**: Melhor suporte a múltiplas instâncias do que cache em memória - -### Implementação do TaskManager com Redis - -O `A2ATaskManager` implementará a mesma interface que o `TaskManager` do exemplo, mas utilizando Redis: - -```python -class A2ATaskManager: - """ - Gerenciador de tarefas A2A usando Redis para armazenamento. - Implementa a interface do protocolo A2A para gerenciamento do ciclo de vida das tarefas. - """ - - def __init__( - self, - redis_cache: RedisCacheService, - session_service=None, - artifacts_service=None, - memory_service=None, - push_notification_service=None - ): - self.redis_cache = redis_cache - self.session_service = session_service - self.artifacts_service = artifacts_service - self.memory_service = memory_service - self.push_notification_service = push_notification_service - self.lock = asyncio.Lock() - self.subscriber_lock = asyncio.Lock() - self.task_sse_subscribers = {} - - async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: - """ - Obtém o status atual de uma tarefa. - - Args: - request: Requisição JSON-RPC para obter dados da tarefa - - Returns: - Resposta com dados da tarefa ou erro - """ - logger.info(f"Getting task {request.params.id}") - task_query_params = request.params - - task_data = self.redis_cache.get(f"task:{task_query_params.id}") - if not task_data: - return GetTaskResponse(id=request.id, error=TaskNotFoundError()) - - # Processar histórico conforme solicitado - if task_query_params.historyLength and task_data.get("history"): - task_data["history"] = task_data["history"][-task_query_params.historyLength:] - - return GetTaskResponse(id=request.id, result=task_data) -``` - -## Implementação do A2A Server Service - -O serviço `A2AServer` processará as requisições JSON-RPC conforme o protocolo A2A: - -```python -class A2AServer: - """ - Servidor A2A que implementa o protocolo JSON-RPC para processamento de tarefas de agentes. - """ - - def __init__( - self, - endpoint: str = "/", - agent_card = None, - task_manager = None, - streaming_service = None, - ): - self.endpoint = endpoint - self.agent_card = agent_card - self.task_manager = task_manager - self.streaming_service = streaming_service - - async def _process_request(self, request: Request): - """ - Processa uma requisição JSON-RPC do protocolo A2A. - - Args: - request: Requisição HTTP - - Returns: - Resposta JSON-RPC ou stream de eventos - """ - try: - body = await request.json() - json_rpc_request = A2ARequest.validate_python(body) - - # Delegar para o handler apropriado com base no tipo de requisição - if isinstance(json_rpc_request, GetTaskRequest): - result = await self.task_manager.on_get_task(json_rpc_request) - elif isinstance(json_rpc_request, SendTaskRequest): - result = await self.task_manager.on_send_task(json_rpc_request) - elif isinstance(json_rpc_request, SendTaskStreamingRequest): - result = await self.task_manager.on_send_task_subscribe(json_rpc_request) - elif isinstance(json_rpc_request, CancelTaskRequest): - result = await self.task_manager.on_cancel_task(json_rpc_request) - elif isinstance(json_rpc_request, SetTaskPushNotificationRequest): - result = await self.task_manager.on_set_task_push_notification(json_rpc_request) - elif isinstance(json_rpc_request, GetTaskPushNotificationRequest): - result = await self.task_manager.on_get_task_push_notification(json_rpc_request) - elif isinstance(json_rpc_request, TaskResubscriptionRequest): - result = await self.task_manager.on_resubscribe_to_task(json_rpc_request) - else: - logger.warning(f"Unexpected request type: {type(json_rpc_request)}") - raise ValueError(f"Unexpected request type: {type(json_rpc_request)}") - - return self._create_response(result) - - except Exception as e: - return self._handle_exception(e) -``` - -## Implementação da Autenticação para Push Notifications - -Implementaremos autenticação JWT para notificações push, seguindo o exemplo: - -```python -class PushNotificationAuth: - def __init__(self): - self.public_keys = [] - self.private_key_jwk = None - - def generate_jwk(self): - key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig") - self.public_keys.append(key.export_public(as_dict=True)) - self.private_key_jwk = PyJWK.from_json(key.export_private()) - - def handle_jwks_endpoint(self, request: Request): - """Retorna as chaves públicas para clientes.""" - return JSONResponse({"keys": self.public_keys}) - - async def send_authenticated_push_notification(self, url: str, data: dict): - """Envia notificação push assinada com JWT.""" - jwt_token = self._generate_jwt(data) - headers = {'Authorization': f"Bearer {jwt_token}"} - async with httpx.AsyncClient(timeout=10) as client: - try: - response = await client.post(url, json=data, headers=headers) - response.raise_for_status() - logger.info(f"Push-notification sent to URL: {url}") - except Exception as e: - logger.warning(f"Error sending push-notification to URL {url}: {e}") -``` - -## Revisão das Rotas A2A - -As rotas A2A serão implementadas em `src/api/a2a_routes.py`, utilizando a lógica de AgentCard do código existente: - -```python -@router.post("/{agent_id}") -async def process_a2a_request( - agent_id: str, - request: Request, - x_api_key: str = Header(None, alias="x-api-key"), - db: Session = Depends(get_db), -): - """ - Endpoint que processa requisições JSON-RPC do protocolo A2A. - """ - # Validar agente e API key - agent = get_agent(db, agent_id) - if not agent: - return JSONResponse( - status_code=404, - content={"detail": "Agente não encontrado"} - ) - - if agent.config.get("api_key") != x_api_key: - return JSONResponse( - status_code=401, - content={"detail": "Chave API inválida"} - ) - - # Criar Agent Card para o agente (reutilizando lógica existente) - agent_card = create_agent_card_from_agent(agent, db) - - # Configurar o servidor A2A para este agente - a2a_server.agent_card = agent_card - - # Processar a requisição A2A - return await a2a_server._process_request(request) -``` - -## Arquivos a Serem Criados/Atualizados - -### Novos Arquivos - -1. `src/schemas/a2a.py` - Modelos Pydantic para o protocolo A2A -2. `src/services/redis_cache_service.py` - Serviço de cache Redis -3. `src/config/redis.py` - Configuração do cliente Redis -4. `src/utils/a2a_utils.py` - Utilitários para o protocolo A2A -5. `src/services/a2a_task_manager_service.py` - Gerenciador de tarefas A2A -6. `src/services/a2a_server_service.py` - Servidor A2A -7. `src/services/push_notification_auth_service.py` - Autenticação para push notifications -8. `src/api/a2a_routes.py` - Rotas para o protocolo A2A - -### Arquivos a Serem Atualizados - -1. `src/main.py` - Registrar novas rotas A2A -2. `pyproject.toml` - Adicionar dependências (Redis, jwcrypto, etc.) - -## Plano de Implementação - -### Fase 1: Criação dos Esquemas - -1. Criar arquivo `src/schemas/a2a.py` com os modelos Pydantic baseados no arquivo `common/types.py` -2. Adaptar os tipos para a estrutura do projeto e adicionar suporte para streaming e push notifications - -### Fase 2: Implementação do Serviço de Cache Redis - -1. Criar arquivo `src/config/redis.py` para configuração do cliente Redis -2. Criar arquivo `src/services/redis_cache_service.py` para gerenciamento de cache - -### Fase 3: Implementação de Utilitários - -1. Criar arquivo `src/utils/a2a_utils.py` com funções utilitárias baseadas em `common/server/utils.py` -2. Adaptar o `PushNotificationAuth` para uso no contexto A2A - -### Fase 4: Implementação do Gerenciador de Tarefas - -1. Criar arquivo `src/services/a2a_task_manager_service.py` com a implementação do `A2ATaskManager` -2. Integrar com serviços existentes: - - agent_runner.py para execução de agentes - - streaming_service.py para streaming SSE - - push_notification_service.py para push notifications - - redis_cache_service.py para cache de tarefas - -### Fase 5: Implementação do Servidor A2A - -1. Criar arquivo `src/services/a2a_server_service.py` com a implementação do `A2AServer` -2. Implementar processamento de requisições JSON-RPC para todas as operações A2A - -### Fase 6: Integração - -1. Criar arquivo `src/api/a2a_routes.py` com rotas para o protocolo A2A -2. Registrar as rotas no aplicativo FastAPI principal -3. Assegurar que todas as operações A2A funcionem corretamente, incluindo streaming e push notifications - -## Adaptações Necessárias - -1. **Esquemas**: Adaptar os modelos do protocolo A2A para usar os esquemas Pydantic existentes quando possível -2. **Autenticação**: Integrar com o sistema de autenticação existente usando API keys -3. **Streaming**: Adaptar o `StreamingService` existente para o formato de eventos A2A -4. **Push Notifications**: Integrar o `PushNotificationService` existente e adicionar suporte a autenticação JWT -5. **Cache**: Utilizar Redis para armazenamento temporário de tarefas e eventos -6. **Execução de Agentes**: Reutilizar o serviço existente `agent_runner.py` para execução - -## Próximos Passos - -1. Configurar dependências do Redis no projeto -2. Implementar os esquemas em `src/schemas/a2a.py` -3. Implementar o serviço de cache Redis -4. Implementar as funções utilitárias em `src/utils/a2a_utils.py` -5. Implementar o gerenciador de tarefas em `src/services/a2a_task_manager_service.py` -6. Implementar o servidor A2A em `src/services/a2a_server_service.py` -7. Implementar as rotas em `src/api/a2a_routes.py` -8. Registrar as rotas no aplicativo principal `src/main.py` -9. Testar a implementação com casos de uso completos, incluindo streaming e push notifications diff --git a/docker-compose.yml b/docker-compose.yml index fbb7d33a..9eb003d0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,14 +10,21 @@ services: ports: - "8000:8000" environment: - - POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@postgres:5432/evo_ai - - REDIS_HOST=redis - - REDIS_PORT=6379 - - REDIS_PASSWORD= - - JWT_SECRET_KEY=${JWT_SECRET_KEY} - - SENDGRID_API_KEY=${SENDGRID_API_KEY} - - EMAIL_FROM=${EMAIL_FROM} - - APP_URL=${APP_URL} + POSTGRES_CONNECTION_STRING: postgresql://postgres:postgres@postgres:5432/evo_ai + REDIS_HOST: redis + REDIS_PORT: 6379 + REDIS_PASSWORD: "" + REDIS_SSL: "false" + REDIS_KEY_PREFIX: "a2a:" + REDIS_TTL: 3600 + A2A_TASK_TTL: 3600 + A2A_HISTORY_TTL: 86400 + A2A_PUSH_NOTIFICATION_TTL: 3600 + A2A_SSE_CLIENT_TTL: 300 + JWT_SECRET_KEY: ${JWT_SECRET_KEY} + SENDGRID_API_KEY: ${SENDGRID_API_KEY} + EMAIL_FROM: ${EMAIL_FROM} + APP_URL: ${APP_URL} volumes: - ./logs:/app/logs restart: unless-stopped @@ -26,9 +33,9 @@ services: image: postgres:14-alpine container_name: evo-ai-postgres environment: - - POSTGRES_USER=postgres - - POSTGRES_PASSWORD=postgres - - POSTGRES_DB=evo_ai + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: evo_ai ports: - "5432:5432" volumes: @@ -42,6 +49,12 @@ services: - "6379:6379" volumes: - redis_data:/data + command: redis-server --appendonly yes + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 30s + retries: 50 restart: unless-stopped volumes: diff --git a/pyproject.toml b/pyproject.toml index 5a1b2682..8f1751f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,10 @@ dependencies = [ "pydantic[email]==2.11.3", "httpx==0.28.1", "httpx-sse==0.4.0", + "redis==5.3.0", + "sse-starlette==2.3.3", + "jwcrypto==1.5.6", + "pyjwt[crypto]==2.9.0", ] [project.optional-dependencies] diff --git a/src/api/a2a_routes.py b/src/api/a2a_routes.py new file mode 100644 index 00000000..a90b7295 --- /dev/null +++ b/src/api/a2a_routes.py @@ -0,0 +1,395 @@ +""" +Routes for the A2A (Agent-to-Agent) protocol. + +This module implements the standard A2A routes according to the specification. +""" + +import uuid +import logging +from fastapi import APIRouter, Depends, HTTPException, status, Header, Request +from sqlalchemy.orm import Session +from starlette.responses import JSONResponse + +from src.config.database import get_db +from src.services import agent_service +from src.services import ( + RedisCacheService, + AgentRunnerAdapter, + StreamingServiceAdapter, + create_agent_card_from_agent, +) +from src.services.a2a_task_manager_service import A2ATaskManager +from src.services.a2a_server_service import A2AServer +from src.services.agent_runner import run_agent +from src.services.service_providers import ( + session_service, + artifacts_service, + memory_service, +) +from src.services.push_notification_service import push_notification_service +from src.services.push_notification_auth_service import push_notification_auth +from src.services.streaming_service import StreamingService + +logger = logging.getLogger(__name__) + +# Create router with prefix /a2a according to the standard protocol +router = APIRouter( + prefix="/a2a", + tags=["a2a"], + responses={ + 404: {"description": "Not found"}, + 400: {"description": "Bad request"}, + 401: {"description": "Unauthorized"}, + 500: {"description": "Internal server error"}, + }, +) + +# Singleton instances for shared resources +streaming_service = StreamingService() +redis_cache_service = RedisCacheService() +streaming_adapter = StreamingServiceAdapter(streaming_service) + +# Cache dictionary para manter instâncias de A2ATaskManager por agente +# Isso evita criar novas instâncias a cada request +_task_manager_cache = {} +_agent_runner_cache = {} + + +def get_agent_runner_adapter(db=None, reuse=True, agent_id=None): + """ + Get or create an agent runner adapter. + + Args: + db: Database session + reuse: Whether to reuse an existing instance + agent_id: Agent ID to use as cache key + + Returns: + Agent runner adapter instance + """ + cache_key = str(agent_id) if agent_id else "default" + logger.info( + f"[DEBUG] get_agent_runner_adapter chamado para agent_id={agent_id}, reuse={reuse}, cache_key={cache_key}" + ) + + if reuse and cache_key in _agent_runner_cache: + adapter = _agent_runner_cache[cache_key] + logger.info( + f"[DEBUG] Reutilizando agent_runner_adapter existente para {cache_key}" + ) + # Atualizar a sessão DB se fornecida + if db is not None: + adapter.db = db + return adapter + + logger.info( + f"[IMPORTANTE] Criando NOVA instância de AgentRunnerAdapter para {cache_key}" + ) + adapter = AgentRunnerAdapter( + agent_runner_func=run_agent, + session_service=session_service, + artifacts_service=artifacts_service, + memory_service=memory_service, + db=db, + ) + + if reuse: + logger.info(f"[DEBUG] Armazenando nova instância no cache para {cache_key}") + _agent_runner_cache[cache_key] = adapter + + return adapter + + +def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"): + cache_key = str(agent_id) + + # Para operações de consulta, NUNCA crie um agent_runner + if operation_type == "query": + if cache_key in _task_manager_cache: + # Reutilize existente + task_manager = _task_manager_cache[cache_key] + task_manager.db = db + return task_manager + + # Se não existe, crie um task_manager SEM agent_runner para consultas + return A2ATaskManager( + redis_cache=redis_cache_service, + agent_runner=None, # Sem agent_runner para consultas! + streaming_service=streaming_adapter, + push_notification_service=push_notification_service, + db=db, + ) + + # Para operações de execução, use o fluxo normal + if reuse and cache_key in _task_manager_cache: + # Atualize o db + task_manager = _task_manager_cache[cache_key] + task_manager.db = db + return task_manager + + # Create new + agent_runner_adapter = get_agent_runner_adapter( + db=db, reuse=reuse, agent_id=agent_id + ) + task_manager = A2ATaskManager( + redis_cache=redis_cache_service, + agent_runner=agent_runner_adapter, + streaming_service=streaming_adapter, + push_notification_service=push_notification_service, + db=db, + ) + _task_manager_cache[cache_key] = task_manager + return task_manager + + +@router.post("/{agent_id}") +async def process_a2a_request( + agent_id: uuid.UUID, + request: Request, + x_api_key: str = Header(None, alias="x-api-key"), + db: Session = Depends(get_db), +): + """ + Main endpoint for processing JSON-RPC requests of the A2A protocol. + + This endpoint processes all JSON-RPC methods of the A2A protocol, including: + - tasks/send: Sending tasks + - tasks/sendSubscribe: Sending tasks with streaming + - tasks/get: Querying task status + - tasks/cancel: Cancelling tasks + - tasks/pushNotification/set: Setting push notifications + - tasks/pushNotification/get: Querying push notification configurations + - tasks/resubscribe: Resubscribing to receive task updates + + Args: + agent_id: Agent ID + request: HTTP request with JSON-RPC payload + x_api_key: API key for authentication + db: Database session + + Returns: + JSON-RPC response or streaming (SSE) depending on the method + """ + try: + # Detailed request log + logger.info(f"Request received for A2A agent {agent_id}") + logger.info(f"Headers: {dict(request.headers)}") + + try: + body = await request.json() + method = body.get("method", "unknown") + logger.info(f"[IMPORTANTE] Método solicitado: {method}") + logger.info(f"Request body: {body}") + + # Determinar se é uma solicitação de consulta (get_task) ou execução (send_task) + is_query_request = method in [ + "tasks/get", + "tasks/cancel", + "tasks/pushNotification/get", + "tasks/resubscribe", + ] + # Para consultas, reutilizamos os componentes; para execuções, + # criamos novos para garantir estado limpo + reuse_components = is_query_request + logger.info( + f"[IMPORTANTE] Is query request: {is_query_request}, Reuse components: {reuse_components}" + ) + + except Exception as e: + logger.error(f"Error reading request body: {e}") + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32700, + "message": f"Parse error: {str(e)}", + "data": None, + }, + }, + ) + + # Verify if the agent exists + agent = agent_service.get_agent(db, agent_id) + if agent is None: + logger.warning(f"Agent not found: {agent_id}") + return JSONResponse( + status_code=404, + content={ + "jsonrpc": "2.0", + "id": None, + "error": {"code": 404, "message": "Agent not found", "data": None}, + }, + ) + + # Verify API key + agent_config = agent.config + logger.info(f"Received API Key: {x_api_key}") + logger.info(f"Expected API Key: {agent_config.get('api_key')}") + + if x_api_key and agent_config.get("api_key") != x_api_key: + logger.warning(f"Invalid API Key for agent {agent_id}") + return JSONResponse( + status_code=401, + content={ + "jsonrpc": "2.0", + "id": None, + "error": {"code": 401, "message": "Invalid API key", "data": None}, + }, + ) + + # Obter o task manager para este agente (reutilizando se possível) + a2a_task_manager = get_task_manager( + agent_id, + db=db, + reuse=reuse_components, + operation_type="query" if is_query_request else "execution", + ) + a2a_server = A2AServer(task_manager=a2a_task_manager) + + # Configure agent_card for the A2A server + logger.info("Configuring agent_card for A2A server") + agent_card = create_agent_card_from_agent(agent, db) + a2a_server.agent_card = agent_card + + # Verify JSON-RPC format + if not body.get("jsonrpc") or body.get("jsonrpc") != "2.0": + logger.error(f"Invalid JSON-RPC format: {body.get('jsonrpc')}") + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32600, + "message": "Invalid Request: jsonrpc must be '2.0'", + "data": None, + }, + }, + ) + + # Verify the method + if not body.get("method"): + logger.error("Method not specified in request") + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32600, + "message": "Invalid Request: method is required", + "data": None, + }, + }, + ) + + logger.info(f"Processing method: {body.get('method')}") + + # Process the request with the A2A server + logger.info("Sending request to A2A server") + + # Pass the agent_id and db directly to the process_request method + return await a2a_server.process_request(request, agent_id=str(agent_id), db=db) + + except Exception as e: + logger.error(f"Error processing A2A request: {str(e)}", exc_info=True) + return JSONResponse( + status_code=500, + content={ + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32603, + "message": "Internal server error", + "data": {"detail": str(e)}, + }, + }, + ) + + +@router.get("/{agent_id}/.well-known/agent.json") +async def get_agent_card( + agent_id: uuid.UUID, + request: Request, + db: Session = Depends(get_db), +): + """ + Endpoint to get the Agent Card in the .well-known format of the A2A protocol. + + This endpoint returns the agent information in the standard A2A format, + including capabilities, authentication information, and skills. + + Args: + agent_id: Agent ID + request: HTTP request + db: Database session + + Returns: + Agent Card in JSON format + """ + try: + agent = agent_service.get_agent(db, agent_id) + if agent is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" + ) + + agent_card = create_agent_card_from_agent(agent, db) + + # Obter o task manager para este agente (reutilizando se possível) + a2a_task_manager = get_task_manager(agent_id, db=db, reuse=True) + a2a_server = A2AServer(task_manager=a2a_task_manager) + + # Configure the A2A server with the agent card + a2a_server.agent_card = agent_card + + # Use the A2A server to deliver the agent card, ensuring protocol compatibility + return await a2a_server.get_agent_card(request, db=db) + + except Exception as e: + logger.error(f"Error generating agent card: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error generating agent card", + ) + + +@router.get("/{agent_id}/.well-known/jwks.json") +async def get_jwks( + agent_id: uuid.UUID, + request: Request, + db: Session = Depends(get_db), +): + """ + Endpoint to get the public JWKS keys for verifying the authenticity + of push notifications. + + Clients can use these keys to verify the authenticity of received notifications. + + Args: + agent_id: Agent ID + request: HTTP request + db: Database session + + Returns: + JSON with the public keys in JWKS format + """ + try: + # Verify if the agent exists + agent = agent_service.get_agent(db, agent_id) + if agent is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" + ) + + # Return the public keys + return push_notification_auth.handle_jwks_endpoint(request) + + except Exception as e: + logger.error(f"Error obtaining JWKS: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error obtaining JWKS", + ) diff --git a/src/api/agent_routes.py b/src/api/agent_routes.py index 19ab00e3..4ac0f072 100644 --- a/src/api/agent_routes.py +++ b/src/api/agent_routes.py @@ -1,6 +1,3 @@ -from datetime import datetime -import asyncio -import os from fastapi import APIRouter, Depends, HTTPException, status, Header from sqlalchemy.orm import Session from src.config.database import get_db @@ -18,19 +15,7 @@ from src.services import ( agent_service, mcp_server_service, ) -from src.services.agent_runner import run_agent -from src.services.service_providers import ( - session_service, - artifacts_service, - memory_service, -) -from src.services.push_notification_service import push_notification_service import logging -from fastapi.responses import StreamingResponse -from ..services.streaming_service import StreamingService -from ..schemas.streaming import JSONRPCRequest - -from src.services.session_service import get_session_events logger = logging.getLogger(__name__) @@ -79,8 +64,6 @@ router = APIRouter( responses={404: {"description": "Not found"}}, ) -streaming_service = StreamingService() - @router.post("/", response_model=Agent, status_code=status.HTTP_201_CREATED) async def create_agent( @@ -172,348 +155,3 @@ async def delete_agent( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" ) - - -@router.get("/{agent_id}/.well-known/agent.json") -async def get_agent_json( - agent_id: uuid.UUID, - db: Session = Depends(get_db), -): - try: - agent = agent_service.get_agent(db, agent_id) - if agent is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" - ) - - mcp_servers = agent.config.get("mcp_servers", []) - formatted_tools = await format_agent_tools(mcp_servers, db) - - AGENT_CARD = { - "name": agent.name, - "description": agent.description, - "url": f"{os.getenv('API_URL', '')}/api/v1/agents/{agent.id}", - "provider": { - "organization": os.getenv("ORGANIZATION_NAME", ""), - "url": os.getenv("ORGANIZATION_URL", ""), - }, - "version": os.getenv("API_VERSION", ""), - "capabilities": { - "streaming": True, - "pushNotifications": True, - "stateTransitionHistory": True, - }, - "authentication": { - "schemes": ["apiKey"], - "credentials": {"in": "header", "name": "x-api-key"}, - }, - "defaultInputModes": ["text", "application/json"], - "defaultOutputModes": ["text", "application/json"], - "skills": formatted_tools, - } - return AGENT_CARD - except Exception as e: - logger.error(f"Error generating agent card: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error generating agent card", - ) - - -@router.post("/{agent_id}/tasks/send") -async def handle_task( - agent_id: uuid.UUID, - request: JSONRPCRequest, - x_api_key: str = Header(..., alias="x-api-key"), - db: Session = Depends(get_db), -): - """Endpoint to clients A2A send a new task (with an initial user message).""" - try: - # Verify agent - agent = agent_service.get_agent(db, agent_id) - if agent is None: - return { - "jsonrpc": "2.0", - "id": request.id, - "error": {"code": 404, "message": "Agent not found", "data": None}, - } - - # Verify API key - agent_config = agent.config - if agent_config.get("api_key") != x_api_key: - return { - "jsonrpc": "2.0", - "id": request.id, - "error": {"code": 401, "message": "Invalid API key", "data": None}, - } - - # Extract task request from JSON-RPC params - task_request = request.params - - # Validate required fields - task_id = task_request.get("id") - if not task_id: - return { - "jsonrpc": "2.0", - "id": request.id, - "error": { - "code": -32602, - "message": "Invalid parameters", - "data": {"detail": "Task ID is required"}, - }, - } - - # Extract user message - try: - user_message = task_request["message"]["parts"][0]["text"] - except (KeyError, IndexError): - return { - "jsonrpc": "2.0", - "id": request.id, - "error": { - "code": -32602, - "message": "Invalid parameters", - "data": {"detail": "Invalid message format"}, - }, - } - - # Configure session and metadata - session_id = f"{task_id}_{agent_id}" - metadata = task_request.get("metadata", {}) - history_length = metadata.get("historyLength", 50) - - # Initialize response - response_task = { - "id": task_id, - "sessionId": session_id, - "status": { - "state": "submitted", - "timestamp": datetime.now().isoformat(), - "message": None, - "error": None, - }, - "artifacts": [], - "history": [], - "metadata": metadata, - } - - # Handle push notification configuration - push_notification = task_request.get("pushNotification") - if push_notification: - url = push_notification.get("url") - headers = push_notification.get("headers", {}) - - if not url: - return { - "jsonrpc": "2.0", - "id": request.id, - "error": { - "code": -32602, - "message": "Invalid parameters", - "data": {"detail": "Push notification URL is required"}, - }, - } - - # Store push notification config in metadata - response_task["metadata"]["pushNotification"] = { - "url": url, - "headers": headers, - } - - # Send initial notification - asyncio.create_task( - push_notification_service.send_notification( - url=url, task_id=task_id, state="submitted", headers=headers - ) - ) - - try: - # Update status to running - response_task["status"].update( - {"state": "running", "timestamp": datetime.now().isoformat()} - ) - - # Send running notification if configured - if push_notification: - asyncio.create_task( - push_notification_service.send_notification( - url=url, task_id=task_id, state="running", headers=headers - ) - ) - - # Execute agent - final_response_text = await run_agent( - str(agent_id), - task_id, - user_message, - session_service, - artifacts_service, - memory_service, - db, - session_id, - ) - - # Update status to completed - response_task["status"].update( - { - "state": "completed", - "timestamp": datetime.now().isoformat(), - "message": { - "role": "agent", - "parts": [{"type": "text", "text": final_response_text}], - }, - } - ) - - # Add artifacts - if final_response_text: - response_task["artifacts"].append( - { - "type": "text", - "content": final_response_text, - "metadata": { - "generated_at": datetime.now().isoformat(), - "content_type": "text/plain", - }, - } - ) - - # Send completed notification if configured - if push_notification: - asyncio.create_task( - push_notification_service.send_notification( - url=url, - task_id=task_id, - state="completed", - message={ - "role": "agent", - "parts": [{"type": "text", "text": final_response_text}], - }, - headers=headers, - ) - ) - - except Exception as e: - # Update status to failed - response_task["status"].update( - { - "state": "failed", - "timestamp": datetime.now().isoformat(), - "error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)}, - } - ) - - # Send failed notification if configured - if push_notification: - asyncio.create_task( - push_notification_service.send_notification( - url=url, - task_id=task_id, - state="failed", - message={ - "role": "system", - "parts": [{"type": "text", "text": str(e)}], - }, - headers=headers, - ) - ) - - # Process history - try: - history_messages = get_session_events(session_service, session_id) - history_messages = history_messages[-history_length:] - - formatted_history = [] - for event in history_messages: - if event.content and event.content.parts: - role = ( - "agent" if event.content.role == "model" else event.content.role - ) - formatted_history.append( - { - "role": role, - "parts": [ - {"type": "text", "text": part.text} - for part in event.content.parts - if part.text - ], - } - ) - - response_task["history"] = formatted_history - - except Exception as e: - logger.error(f"Error processing history: {str(e)}") - - # Return JSON-RPC response - return {"jsonrpc": "2.0", "id": request.id, "result": response_task} - - except HTTPException as e: - return { - "jsonrpc": "2.0", - "id": request.id, - "error": {"code": e.status_code, "message": e.detail, "data": None}, - } - except Exception as e: - logger.error(f"Unexpected error in handle_task: {str(e)}") - return { - "jsonrpc": "2.0", - "id": request.id, - "error": { - "code": -32603, - "message": "Internal server error", - "data": {"detail": str(e)}, - }, - } - - -@router.post("/{agent_id}/tasks/sendSubscribe") -async def subscribe_task_streaming( - agent_id: str, - request: JSONRPCRequest, - x_api_key: str = Header(None), - db: Session = Depends(get_db), -): - """ - Endpoint para streaming de eventos SSE de uma tarefa. - - Args: - agent_id: ID do agente - request: Requisição JSON-RPC - x_api_key: Chave de API no header - db: Sessão do banco de dados - - Returns: - StreamingResponse com eventos SSE - """ - if not x_api_key: - return { - "jsonrpc": "2.0", - "id": request.id, - "error": {"code": 401, "message": "API key é obrigatória", "data": None}, - } - - # Extrai mensagem do payload - message = request.params.get("message", {}).get("parts", [{}])[0].get("text", "") - session_id = request.params.get("sessionId") - - # Configura streaming - async def event_generator(): - async for event in streaming_service.send_task_streaming( - agent_id=agent_id, - api_key=x_api_key, - message=message, - session_id=session_id, - db=db, - ): - yield event - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) diff --git a/src/config/redis.py b/src/config/redis.py new file mode 100644 index 00000000..29167e52 --- /dev/null +++ b/src/config/redis.py @@ -0,0 +1,85 @@ +""" +Redis configuration module. + +This module defines the Redis connection settings and provides +function to create a Redis connection pool for the application. +""" + +import os +import redis +from dotenv import load_dotenv +import logging + +# Load environment variables +load_dotenv() + +logger = logging.getLogger(__name__) + + +def get_redis_config(): + """ + Get Redis configuration from environment variables. + + Returns: + dict: Redis configuration parameters + """ + return { + "host": os.getenv("REDIS_HOST", "localhost"), + "port": int(os.getenv("REDIS_PORT", 6379)), + "db": int(os.getenv("REDIS_DB", 0)), + "password": os.getenv("REDIS_PASSWORD", None), + "ssl": os.getenv("REDIS_SSL", "false").lower() == "true", + "key_prefix": os.getenv("REDIS_KEY_PREFIX", "a2a:"), + "default_ttl": int(os.getenv("REDIS_TTL", 3600)), + } + + +def get_a2a_config(): + """ + Get A2A-specific cache TTL values from environment variables. + + Returns: + dict: A2A TTL configuration parameters + """ + return { + "task_ttl": int(os.getenv("A2A_TASK_TTL", 3600)), + "history_ttl": int(os.getenv("A2A_HISTORY_TTL", 86400)), + "push_notification_ttl": int(os.getenv("A2A_PUSH_NOTIFICATION_TTL", 3600)), + "sse_client_ttl": int(os.getenv("A2A_SSE_CLIENT_TTL", 300)), + } + + +def create_redis_pool(config=None): + """ + Create and return a Redis connection pool. + + Args: + config (dict, optional): Redis configuration. If None, + configuration is loaded from environment + + Returns: + redis.ConnectionPool: Redis connection pool + """ + if config is None: + config = get_redis_config() + + try: + connection_pool = redis.ConnectionPool( + host=config["host"], + port=config["port"], + db=config["db"], + password=config["password"] if config["password"] else None, + ssl=config["ssl"], + decode_responses=True, + ) + # Test the connection + redis_client = redis.Redis(connection_pool=connection_pool) + redis_client.ping() + logger.info( + f"Redis connection successful: {config['host']}:{config['port']}, " + f"db={config['db']}, ssl={config['ssl']}" + ) + return connection_pool + except redis.RedisError as e: + logger.error(f"Redis connection error: {e}") + raise diff --git a/src/config/settings.py b/src/config/settings.py index 068f05d0..c5bb76ec 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -37,6 +37,9 @@ class Settings(BaseSettings): REDIS_PORT: int = int(os.getenv("REDIS_PORT", 6379)) REDIS_DB: int = int(os.getenv("REDIS_DB", 0)) REDIS_PASSWORD: Optional[str] = os.getenv("REDIS_PASSWORD") + REDIS_SSL: bool = os.getenv("REDIS_SSL", "false").lower() == "true" + REDIS_KEY_PREFIX: str = os.getenv("REDIS_KEY_PREFIX", "evoai:") + REDIS_TTL: int = int(os.getenv("REDIS_TTL", 3600)) # Tool cache TTL in seconds (1 hour) TOOLS_CACHE_TTL: int = int(os.getenv("TOOLS_CACHE_TTL", 3600)) diff --git a/src/main.py b/src/main.py index 6d862b63..327b6ec8 100644 --- a/src/main.py +++ b/src/main.py @@ -22,6 +22,7 @@ import src.api.contact_routes import src.api.mcp_server_routes import src.api.tool_routes import src.api.client_routes +import src.api.a2a_routes # Add the root directory to PYTHONPATH root_dir = Path(__file__).parent.parent @@ -40,13 +41,13 @@ app = FastAPI( # CORS configuration app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Permite todas as origens em desenvolvimento + allow_origins=["*"], # Allows all origins in development allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) -# Configuração de arquivos estáticos +# Static files configuration static_dir = Path("static") if not static_dir.exists(): static_dir.mkdir(parents=True) @@ -72,6 +73,7 @@ contact_router = src.api.contact_routes.router mcp_server_router = src.api.mcp_server_routes.router tool_router = src.api.tool_routes.router client_router = src.api.client_routes.router +a2a_router = src.api.a2a_routes.router # Include routes app.include_router(auth_router, prefix=API_PREFIX) @@ -83,6 +85,7 @@ app.include_router(chat_router, prefix=API_PREFIX) app.include_router(session_router, prefix=API_PREFIX) app.include_router(agent_router, prefix=API_PREFIX) app.include_router(contact_router, prefix=API_PREFIX) +app.include_router(a2a_router, prefix=API_PREFIX) @app.get("/") diff --git a/src/schemas/a2a/__init__.py b/src/schemas/a2a/__init__.py new file mode 100644 index 00000000..68293cc2 --- /dev/null +++ b/src/schemas/a2a/__init__.py @@ -0,0 +1,9 @@ +""" +A2A (Agent-to-Agent) schema package. + +This package contains Pydantic schema definitions for the A2A protocol. +""" + +from src.schemas.a2a.types import * +from src.schemas.a2a.exceptions import * +from src.schemas.a2a.validators import * diff --git a/src/schemas/a2a/exceptions.py b/src/schemas/a2a/exceptions.py new file mode 100644 index 00000000..19dce091 --- /dev/null +++ b/src/schemas/a2a/exceptions.py @@ -0,0 +1,147 @@ +""" +A2A (Agent-to-Agent) protocol exception definitions. + +This module contains error types and exceptions for the A2A protocol. +""" + +from src.schemas.a2a.types import JSONRPCError + + +class JSONParseError(JSONRPCError): + """ + Error raised when JSON parsing fails. + """ + + code: int = -32700 + message: str = "Invalid JSON payload" + data: object | None = None + + +class InvalidRequestError(JSONRPCError): + """ + Error raised when request validation fails. + """ + + code: int = -32600 + message: str = "Request payload validation error" + data: object | None = None + + +class MethodNotFoundError(JSONRPCError): + """ + Error raised when the requested method is not found. + """ + + code: int = -32601 + message: str = "Method not found" + data: None = None + + +class InvalidParamsError(JSONRPCError): + """ + Error raised when the parameters are invalid. + """ + + code: int = -32602 + message: str = "Invalid parameters" + data: object | None = None + + +class InternalError(JSONRPCError): + """ + Error raised when an internal error occurs. + """ + + code: int = -32603 + message: str = "Internal error" + data: object | None = None + + +class TaskNotFoundError(JSONRPCError): + """ + Error raised when the requested task is not found. + """ + + code: int = -32001 + message: str = "Task not found" + data: None = None + + +class TaskNotCancelableError(JSONRPCError): + """ + Error raised when a task cannot be canceled. + """ + + code: int = -32002 + message: str = "Task cannot be canceled" + data: None = None + + +class PushNotificationNotSupportedError(JSONRPCError): + """ + Error raised when push notifications are not supported. + """ + + code: int = -32003 + message: str = "Push Notification is not supported" + data: None = None + + +class UnsupportedOperationError(JSONRPCError): + """ + Error raised when an operation is not supported. + """ + + code: int = -32004 + message: str = "This operation is not supported" + data: None = None + + +class ContentTypeNotSupportedError(JSONRPCError): + """ + Error raised when content types are incompatible. + """ + + code: int = -32005 + message: str = "Incompatible content types" + data: None = None + + +# Client exceptions + + +class A2AClientError(Exception): + """ + Base exception for A2A client errors. + """ + + pass + + +class A2AClientHTTPError(A2AClientError): + """ + Exception for HTTP errors in A2A client. + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + super().__init__(f"HTTP Error {status_code}: {message}") + + +class A2AClientJSONError(A2AClientError): + """ + Exception for JSON errors in A2A client. + """ + + def __init__(self, message: str): + self.message = message + super().__init__(f"JSON Error: {message}") + + +class MissingAPIKeyError(Exception): + """ + Exception for missing API key. + """ + + pass diff --git a/src/schemas/a2a/types.py b/src/schemas/a2a/types.py new file mode 100644 index 00000000..9efcb7b3 --- /dev/null +++ b/src/schemas/a2a/types.py @@ -0,0 +1,464 @@ +""" +A2A (Agent-to-Agent) protocol type definitions. + +This module contains Pydantic schema definitions for the A2A protocol. +""" + +from typing import Union, Any, List, Optional, Annotated, Literal +from pydantic import ( + BaseModel, + Field, + TypeAdapter, + field_serializer, + model_validator, + ConfigDict, +) +from datetime import datetime +from uuid import uuid4 +from enum import Enum +from typing_extensions import Self + + +class TaskState(str, Enum): + """ + Enum for the state of a task in the A2A protocol. + + States follow the A2A protocol specification. + """ + + SUBMITTED = "submitted" + WORKING = "working" + INPUT_REQUIRED = "input-required" + COMPLETED = "completed" + CANCELED = "canceled" + FAILED = "failed" + UNKNOWN = "unknown" + + +class TextPart(BaseModel): + """ + Represents a text part in a message. + """ + + type: Literal["text"] = "text" + text: str + metadata: dict[str, Any] | None = None + + +class FileContent(BaseModel): + """ + Represents file content in a file part. + + Either bytes or uri must be provided, but not both. + """ + + name: str | None = None + mimeType: str | None = None + bytes: str | None = None + uri: str | None = None + + @model_validator(mode="after") + def check_content(self) -> Self: + """ + Validates that either bytes or uri is present, but not both. + """ + if not (self.bytes or self.uri): + raise ValueError("Either 'bytes' or 'uri' must be present in the file data") + if self.bytes and self.uri: + raise ValueError( + "Only one of 'bytes' or 'uri' can be present in the file data" + ) + return self + + +class FilePart(BaseModel): + """ + Represents a file part in a message. + """ + + type: Literal["file"] = "file" + file: FileContent + metadata: dict[str, Any] | None = None + + +class DataPart(BaseModel): + """ + Represents a data part in a message. + """ + + type: Literal["data"] = "data" + data: dict[str, Any] + metadata: dict[str, Any] | None = None + + +Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")] + + +class Message(BaseModel): + """ + Represents a message in the A2A protocol. + + A message consists of a role and one or more parts. + """ + + role: Literal["user", "agent"] + parts: List[Part] + metadata: dict[str, Any] | None = None + + +class TaskStatus(BaseModel): + """ + Represents the status of a task. + """ + + state: TaskState + message: Message | None = None + timestamp: datetime = Field(default_factory=datetime.now) + + @field_serializer("timestamp") + def serialize_dt(self, dt: datetime, _info): + """ + Serializes datetime to ISO format. + """ + return dt.isoformat() + + +class Artifact(BaseModel): + """ + Represents an artifact produced by an agent. + """ + + name: str | None = None + description: str | None = None + parts: List[Part] + metadata: dict[str, Any] | None = None + index: int = 0 + append: bool | None = None + lastChunk: bool | None = None + + +class Task(BaseModel): + """ + Represents a task in the A2A protocol. + """ + + id: str + sessionId: str | None = None + status: TaskStatus + artifacts: List[Artifact] | None = None + history: List[Message] | None = None + metadata: dict[str, Any] | None = None + + +class TaskStatusUpdateEvent(BaseModel): + """ + Represents a task status update event. + """ + + id: str + status: TaskStatus + final: bool = False + metadata: dict[str, Any] | None = None + + +class TaskArtifactUpdateEvent(BaseModel): + """ + Represents a task artifact update event. + """ + + id: str + artifact: Artifact + metadata: dict[str, Any] | None = None + + +class AuthenticationInfo(BaseModel): + """ + Represents authentication information for push notifications. + """ + + model_config = ConfigDict(extra="allow") + + schemes: List[str] + credentials: str | None = None + + +class PushNotificationConfig(BaseModel): + """ + Represents push notification configuration. + """ + + url: str + token: str | None = None + authentication: AuthenticationInfo | None = None + + +class TaskIdParams(BaseModel): + """ + Represents parameters for identifying a task. + """ + + id: str + metadata: dict[str, Any] | None = None + + +class TaskQueryParams(TaskIdParams): + """ + Represents parameters for querying a task. + """ + + historyLength: int | None = None + + +class TaskSendParams(BaseModel): + """ + Represents parameters for sending a task. + """ + + id: str + sessionId: str = Field(default_factory=lambda: uuid4().hex) + message: Message + acceptedOutputModes: Optional[List[str]] = None + pushNotification: PushNotificationConfig | None = None + historyLength: int | None = None + metadata: dict[str, Any] | None = None + agentId: str = "" + + +class TaskPushNotificationConfig(BaseModel): + """ + Represents push notification configuration for a task. + """ + + id: str + pushNotificationConfig: PushNotificationConfig + + +# RPC Messages + + +class JSONRPCMessage(BaseModel): + """ + Base class for JSON-RPC messages. + """ + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None = Field(default_factory=lambda: uuid4().hex) + + +class JSONRPCRequest(JSONRPCMessage): + """ + Represents a JSON-RPC request. + """ + + method: str + params: dict[str, Any] | None = None + + +class JSONRPCError(BaseModel): + """ + Represents a JSON-RPC error. + """ + + code: int + message: str + data: Any | None = None + + +class JSONRPCResponse(JSONRPCMessage): + """ + Represents a JSON-RPC response. + """ + + result: Any | None = None + error: JSONRPCError | None = None + + +class SendTaskRequest(JSONRPCRequest): + """ + Represents a request to send a task. + """ + + method: Literal["tasks/send"] = "tasks/send" + params: TaskSendParams + + +class SendTaskResponse(JSONRPCResponse): + """ + Represents a response to a send task request. + """ + + result: Task | None = None + + +class SendTaskStreamingRequest(JSONRPCRequest): + """ + Represents a request to send a task with streaming. + """ + + method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe" + params: TaskSendParams + + +class SendTaskStreamingResponse(JSONRPCResponse): + """ + Represents a streaming response to a send task request. + """ + + result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None + + +class GetTaskRequest(JSONRPCRequest): + """ + Represents a request to get task information. + """ + + method: Literal["tasks/get"] = "tasks/get" + params: TaskQueryParams + + +class GetTaskResponse(JSONRPCResponse): + """ + Represents a response to a get task request. + """ + + result: Task | None = None + + +class CancelTaskRequest(JSONRPCRequest): + """ + Represents a request to cancel a task. + """ + + method: Literal["tasks/cancel",] = "tasks/cancel" + params: TaskIdParams + + +class CancelTaskResponse(JSONRPCResponse): + """ + Represents a response to a cancel task request. + """ + + result: Task | None = None + + +class SetTaskPushNotificationRequest(JSONRPCRequest): + """ + Represents a request to set push notification for a task. + """ + + method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set" + params: TaskPushNotificationConfig + + +class SetTaskPushNotificationResponse(JSONRPCResponse): + """ + Represents a response to a set push notification request. + """ + + result: TaskPushNotificationConfig | None = None + + +class GetTaskPushNotificationRequest(JSONRPCRequest): + """ + Represents a request to get push notification configuration for a task. + """ + + method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get" + params: TaskIdParams + + +class GetTaskPushNotificationResponse(JSONRPCResponse): + """ + Represents a response to a get push notification request. + """ + + result: TaskPushNotificationConfig | None = None + + +class TaskResubscriptionRequest(JSONRPCRequest): + """ + Represents a request to resubscribe to a task. + """ + + method: Literal["tasks/resubscribe",] = "tasks/resubscribe" + params: TaskIdParams + + +# TypeAdapter for discriminating A2A requests by method +A2ARequest = TypeAdapter( + Annotated[ + Union[ + SendTaskRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + TaskResubscriptionRequest, + SendTaskStreamingRequest, + ], + Field(discriminator="method"), + ] +) + + +# Agent Card schemas + + +class AgentProvider(BaseModel): + """ + Represents the provider of an agent. + """ + + organization: str + url: str | None = None + + +class AgentCapabilities(BaseModel): + """ + Represents the capabilities of an agent. + """ + + streaming: bool = False + pushNotifications: bool = False + stateTransitionHistory: bool = False + + +class AgentAuthentication(BaseModel): + """ + Represents the authentication requirements for an agent. + """ + + schemes: List[str] + credentials: str | None = None + + +class AgentSkill(BaseModel): + """ + Represents a skill of an agent. + """ + + id: str + name: str + description: str | None = None + tags: List[str] | None = None + examples: List[str] | None = None + inputModes: List[str] | None = None + outputModes: List[str] | None = None + + +class AgentCard(BaseModel): + """ + Represents an agent card in the A2A protocol. + """ + + name: str + description: str | None = None + url: str + provider: AgentProvider | None = None + version: str + documentationUrl: str | None = None + capabilities: AgentCapabilities + authentication: AgentAuthentication | None = None + defaultInputModes: List[str] = ["text"] + defaultOutputModes: List[str] = ["text"] + skills: List[AgentSkill] diff --git a/src/schemas/a2a/validators.py b/src/schemas/a2a/validators.py new file mode 100644 index 00000000..9ee32811 --- /dev/null +++ b/src/schemas/a2a/validators.py @@ -0,0 +1,124 @@ +""" +A2A (Agent-to-Agent) protocol validators. + +This module contains validators for the A2A protocol data. +""" + +from typing import List +import base64 +import re +from pydantic import ValidationError +import logging +from src.schemas.a2a.types import Part, TextPart, FilePart, DataPart, FileContent + +logger = logging.getLogger(__name__) + + +def validate_base64(value: str) -> bool: + """ + Validates if a string is valid base64. + + Args: + value: String to validate + + Returns: + True if valid base64, False otherwise + """ + try: + if not value: + return False + + # Check if the string has base64 characters only + pattern = r"^[A-Za-z0-9+/]+={0,2}$" + if not re.match(pattern, value): + return False + + # Try to decode + base64.b64decode(value) + return True + except Exception as e: + logger.warning(f"Invalid base64 string: {e}") + return False + + +def validate_file_content(file_content: FileContent) -> bool: + """ + Validates file content. + + Args: + file_content: FileContent to validate + + Returns: + True if valid, False otherwise + """ + try: + if file_content.bytes is not None: + return validate_base64(file_content.bytes) + elif file_content.uri is not None: + # Basic URL validation + pattern = r"^https?://.+" + return bool(re.match(pattern, file_content.uri)) + return False + except Exception as e: + logger.warning(f"Invalid file content: {e}") + return False + + +def validate_message_parts(parts: List[Part]) -> bool: + """ + Validates all parts in a message. + + Args: + parts: List of parts to validate + + Returns: + True if all parts are valid, False otherwise + """ + try: + for part in parts: + if isinstance(part, TextPart): + if not part.text or not isinstance(part.text, str): + return False + elif isinstance(part, FilePart): + if not validate_file_content(part.file): + return False + elif isinstance(part, DataPart): + if not part.data or not isinstance(part.data, dict): + return False + else: + return False + return True + except (ValidationError, Exception) as e: + logger.warning(f"Invalid message parts: {e}") + return False + + +def text_to_parts(text: str) -> List[Part]: + """ + Converts a plain text to a list of message parts. + + Args: + text: Plain text to convert + + Returns: + List containing a single TextPart + """ + return [TextPart(text=text)] + + +def parts_to_text(parts: List[Part]) -> str: + """ + Extracts text from a list of message parts. + + Args: + parts: List of parts to extract text from + + Returns: + Concatenated text from all text parts + """ + text = "" + for part in parts: + if isinstance(part, TextPart): + text += part.text + # Could add handling for other part types here + return text diff --git a/src/services/__init__.py b/src/services/__init__.py index 255943f4..9434fe1e 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -1 +1,9 @@ from .agent_runner import run_agent +from .redis_cache_service import RedisCacheService +from .a2a_task_manager_service import A2ATaskManager +from .a2a_server_service import A2AServer +from .a2a_integration_service import ( + AgentRunnerAdapter, + StreamingServiceAdapter, + create_agent_card_from_agent, +) diff --git a/src/services/a2a_integration_service.py b/src/services/a2a_integration_service.py new file mode 100644 index 00000000..b65f3694 --- /dev/null +++ b/src/services/a2a_integration_service.py @@ -0,0 +1,520 @@ +""" +A2A Integration Service. + +This service provides adapters to integrate existing services with the A2A protocol. +""" + +import json +import logging +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional, AsyncIterable + +from src.schemas.a2a import ( + AgentCard, + AgentCapabilities, + AgentProvider, + Artifact, + Message, + TaskArtifactUpdateEvent, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) + +logger = logging.getLogger(__name__) + + +class AgentRunnerAdapter: + """ + Adapter for integrating the existing agent runner with the A2A protocol. + """ + + def __init__( + self, + agent_runner_func, + session_service=None, + artifacts_service=None, + memory_service=None, + db=None, + ): + """ + Initialize the adapter. + + Args: + agent_runner_func: The agent runner function (e.g., run_agent) + session_service: Session service for message history + artifacts_service: Artifacts service for artifact history + memory_service: Memory service for agent memory + db: Database session + """ + self.agent_runner_func = agent_runner_func + self.session_service = session_service + self.artifacts_service = artifacts_service + self.memory_service = memory_service + self.db = db + + async def get_supported_modes(self) -> List[str]: + """ + Get the supported output modes for the agent. + + Returns: + List of supported output modes + """ + # Default modes, can be extended based on agent configuration + return ["text", "application/json"] + + async def run_agent( + self, + agent_id: str, + message: str, + session_id: Optional[str] = None, + task_id: Optional[str] = None, + db=None, + ) -> Dict[str, Any]: + """ + Run the agent with the given message. + + Args: + agent_id: ID of the agent to run + message: User message to process + session_id: Optional session ID for conversation context + task_id: Optional task ID for tracking + db: Database session + + Returns: + Dictionary with the agent's response + """ + logger.info( + f"[AGENT-RUNNER] run_agent iniciado - agent_id={agent_id}, task_id={task_id}, session_id={session_id}" + ) + logger.info( + f"[AGENT-RUNNER] run_agent - message: '{message[:50]}...' (truncado)" + ) + + try: + # Use the existing agent runner function + session_id = session_id or str(uuid.uuid4()) + task_id = task_id or str(uuid.uuid4()) + + # Use the provided db or fallback to self.db + db_session = db if db is not None else self.db + + if db_session is None: + logger.error( + f"[AGENT-RUNNER] No database session available. db={db}, self.db={self.db}" + ) + else: + logger.info( + f"[AGENT-RUNNER] Using database session: {type(db_session).__name__}" + ) + + logger.info( + f"[AGENT-RUNNER] Chamando agent_runner_func com agent_id={agent_id}, contact_id={task_id}" + ) + response_text = await self.agent_runner_func( + agent_id=agent_id, + contact_id=task_id, + message=message, + session_service=self.session_service, + artifacts_service=self.artifacts_service, + memory_service=self.memory_service, + db=db_session, + session_id=session_id, + ) + + logger.info( + f"[AGENT-RUNNER] run_agent concluído com sucesso para agent_id={agent_id}, task_id={task_id}" + ) + logger.info( + f"[AGENT-RUNNER] resposta: '{str(response_text)[:50]}...' (truncado)" + ) + + return { + "status": "success", + "content": response_text, + "session_id": session_id, + "task_id": task_id, + "timestamp": datetime.now().isoformat(), + } + + except Exception as e: + logger.error(f"[AGENT-RUNNER] Error running agent: {e}", exc_info=True) + return { + "status": "error", + "error": str(e), + "session_id": session_id, + "task_id": task_id, + "timestamp": datetime.now().isoformat(), + } + + async def cancel_task(self, task_id: str) -> bool: + """ + Cancel a running task. + + Args: + task_id: ID of the task to cancel + + Returns: + True if successfully canceled, False otherwise + """ + # Currently, the agent runner doesn't support cancellation + # This is a placeholder for future implementation + logger.warning(f"Task cancellation not implemented for task {task_id}") + return False + + +class StreamingServiceAdapter: + """ + Adapter for integrating the existing streaming service with the A2A protocol. + """ + + def __init__(self, streaming_service): + """ + Initialize the adapter. + + Args: + streaming_service: The streaming service instance + """ + self.streaming_service = streaming_service + + async def stream_agent_response( + self, + agent_id: str, + message: str, + api_key: str, + session_id: Optional[str] = None, + task_id: Optional[str] = None, + db=None, + ) -> AsyncIterable[str]: + """ + Stream the agent's response as A2A events. + + Args: + agent_id: ID of the agent + message: User message to process + api_key: API key for authentication + session_id: Optional session ID for conversation context + task_id: Optional task ID for tracking + db: Database session + + Yields: + A2A event objects as JSON strings for SSE (Server-Sent Events) + """ + task_id = task_id or str(uuid.uuid4()) + logger.info(f"Starting streaming response for task {task_id}") + + # Set working status event + working_status = TaskStatus( + state="working", + timestamp=datetime.now(), + message=Message( + role="agent", parts=[TextPart(text="Processing your request...")] + ), + ) + + status_event = TaskStatusUpdateEvent( + id=task_id, status=working_status, final=False + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(status_event.model_dump()) + + content_buffer = "" + final_sent = False + has_error = False + + # Stream from the existing streaming service + try: + logger.info(f"Setting up streaming for agent {agent_id}, task {task_id}") + # To streaming, we use task_id as contact_id + contact_id = task_id + + # Adicionar tratamento de heartbeat para manter conexão ativa + last_event_time = datetime.now() + heartbeat_interval = 20 # segundos + + async for event in self.streaming_service.send_task_streaming( + agent_id=agent_id, + api_key=api_key, + message=message, + contact_id=contact_id, + session_id=session_id, + db=db, + ): + # Atualizar timestamp do último evento + last_event_time = datetime.now() + + # Process the streaming event format + event_data = event.get("data", "{}") + try: + logger.info(f"Processing event data: {event_data[:100]}...") + data = json.loads(event_data) + + # Extract content + if "delta" in data and data["delta"].get("content"): + content = data["delta"]["content"] + content_buffer += content + logger.info(f"Received content chunk: {content[:50]}...") + + # Create artifact update event + artifact = Artifact( + name="response", + parts=[TextPart(text=content)], + index=0, + append=True, + lastChunk=False, + ) + + artifact_event = TaskArtifactUpdateEvent( + id=task_id, artifact=artifact + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(artifact_event.model_dump()) + + # Check if final event + if data.get("done", False) and not final_sent: + logger.info(f"Received final event for task {task_id}") + # Create completed status event + completed_status = TaskStatus( + state="completed", + timestamp=datetime.now(), + message=Message( + role="agent", + parts=[ + TextPart(text=content_buffer or "Task completed") + ], + ), + ) + + # Final artifact with full content + final_artifact = Artifact( + name="response", + parts=[TextPart(text=content_buffer)], + index=0, + append=False, + lastChunk=True, + ) + + # Send the final artifact + final_artifact_event = TaskArtifactUpdateEvent( + id=task_id, artifact=final_artifact + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(final_artifact_event.model_dump()) + + # Send the completed status + final_status_event = TaskStatusUpdateEvent( + id=task_id, + status=completed_status, + final=True, + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(final_status_event.model_dump()) + + final_sent = True + + except json.JSONDecodeError as e: + logger.warning( + f"Received non-JSON event data: {e}. Data: {event_data[:100]}..." + ) + # Handle non-JSON events - simply add to buffer as text + if isinstance(event_data, str): + content_buffer += event_data + + # Create artifact update event + artifact = Artifact( + name="response", + parts=[TextPart(text=event_data)], + index=0, + append=True, + lastChunk=False, + ) + + artifact_event = TaskArtifactUpdateEvent( + id=task_id, artifact=artifact + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(artifact_event.model_dump()) + elif isinstance(event_data, dict): + # Try to extract text from the dictionary + text_value = str(event_data) + content_buffer += text_value + + artifact = Artifact( + name="response", + parts=[TextPart(text=text_value)], + index=0, + append=True, + lastChunk=False, + ) + + artifact_event = TaskArtifactUpdateEvent( + id=task_id, artifact=artifact + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(artifact_event.model_dump()) + + # Enviar heartbeat/keep-alive para manter a conexão SSE aberta + now = datetime.now() + if (now - last_event_time).total_seconds() > heartbeat_interval: + logger.info(f"Sending heartbeat for task {task_id}") + # Enviando evento de keep-alive como um evento de status de "working" + working_heartbeat = TaskStatus( + state="working", + timestamp=now, + message=Message( + role="agent", parts=[TextPart(text="Still processing...")] + ), + ) + heartbeat_event = TaskStatusUpdateEvent( + id=task_id, status=working_heartbeat, final=False + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(heartbeat_event.model_dump()) + last_event_time = now + + # Ensure we send a final event if not already sent + if not final_sent: + logger.info( + f"Stream completed for task {task_id}, sending final status" + ) + # Create completed status event + completed_status = TaskStatus( + state="completed", + timestamp=datetime.now(), + message=Message( + role="agent", + parts=[TextPart(text=content_buffer or "Task completed")], + ), + ) + + # Send the completed status + final_event = TaskStatusUpdateEvent( + id=task_id, status=completed_status, final=True + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(final_event.model_dump()) + + except Exception as e: + has_error = True + logger.error(f"Error in streaming for task {task_id}: {e}", exc_info=True) + + # Create failed status event + failed_status = TaskStatus( + state="failed", + timestamp=datetime.now(), + message=Message( + role="agent", + parts=[ + TextPart( + text=f"Error during streaming: {str(e)}. Partial response: {content_buffer[:200] if content_buffer else 'No content received'}" + ) + ], + ), + ) + + error_event = TaskStatusUpdateEvent( + id=task_id, status=failed_status, final=True + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(error_event.model_dump()) + + finally: + # Garantir que enviamos um evento final para fechar a conexão corretamente + if not final_sent and not has_error: + logger.info(f"Stream finalizing for task {task_id} via finally block") + try: + # Create completed status event + completed_status = TaskStatus( + state="completed", + timestamp=datetime.now(), + message=Message( + role="agent", + parts=[ + TextPart( + text=content_buffer or "Task completed (forced end)" + ) + ], + ), + ) + + # Send the completed status + final_event = TaskStatusUpdateEvent( + id=task_id, status=completed_status, final=True + ) + # IMPORTANTE: Converter para string JSON para SSE + yield json.dumps(final_event.model_dump()) + except Exception as final_error: + logger.error( + f"Error sending final event in finally block: {final_error}" + ) + + logger.info(f"Streaming completed for task {task_id}") + + +def create_agent_card_from_agent(agent, db) -> AgentCard: + """ + Create an A2A agent card from an agent model. + + Args: + agent: The agent model from the database + db: Database session + + Returns: + A2A AgentCard object + """ + import os + from src.api.agent_routes import format_agent_tools + import asyncio + + # Extract agent configuration + agent_config = agent.config + has_streaming = True # Assuming streaming is always supported + has_push = True # Assuming push notifications are supported + + # Format tools as skills + try: + # We use a different approach to handle the asynchronous function + mcp_servers = agent_config.get("mcp_servers", []) + + # We create a new thread to execute the asynchronous function + import concurrent.futures + import functools + + def run_async(coro): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result = loop.run_until_complete(coro) + loop.close() + return result + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async, format_agent_tools(mcp_servers, db)) + skills = future.result() + except Exception as e: + logger.error(f"Error formatting agent tools: {e}") + skills = [] + + # Create agent card + return AgentCard( + name=agent.name, + description=agent.description, + url=f"{os.getenv('API_URL', '')}/api/v1/a2a/{agent.id}", + provider=AgentProvider( + organization=os.getenv("ORGANIZATION_NAME", ""), + url=os.getenv("ORGANIZATION_URL", ""), + ), + version=os.getenv("API_VERSION", "1.0.0"), + capabilities=AgentCapabilities( + streaming=has_streaming, + pushNotifications=has_push, + stateTransitionHistory=True, + ), + authentication={ + "schemes": ["apiKey"], + "credentials": "x-api-key", + }, + defaultInputModes=["text", "application/json"], + defaultOutputModes=["text", "application/json"], + skills=skills, + ) diff --git a/src/services/a2a_server_service.py b/src/services/a2a_server_service.py new file mode 100644 index 00000000..acc6511f --- /dev/null +++ b/src/services/a2a_server_service.py @@ -0,0 +1,739 @@ +""" +Server A2A and task manager for the A2A protocol. + +This module implements a JSON-RPC compatible server for the A2A protocol, +that manages agent tasks, streaming events and push notifications. +""" + +import asyncio +import json +import logging +import uuid +from datetime import datetime +from typing import ( + Any, + Dict, + List, + Optional, + AsyncGenerator, + Callable, + Union, + AsyncIterable, +) +import httpx +from fastapi import Request +from fastapi.responses import JSONResponse, StreamingResponse, Response +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from src.schemas.a2a.types import A2ARequest +from src.services.agent_runner import run_agent +from src.services.a2a_integration_service import ( + AgentRunnerAdapter, + StreamingServiceAdapter, +) +from src.services.session_service import get_session_events +from src.services.redis_cache_service import RedisCacheService +from src.schemas.a2a.types import ( + SendTaskRequest, + SendTaskStreamingRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + TaskResubscriptionRequest, + TaskSendParams, +) +from src.utils.a2a_utils import are_modalities_compatible + +logger = logging.getLogger(__name__) + + +class A2ATaskManager: + """ + Task manager for the A2A protocol. + + This class manages the lifecycle of tasks, including: + - Task execution + - Streaming of events + - Push notifications + - Status querying + - Cancellation + """ + + def __init__( + self, + redis_cache: RedisCacheService, + agent_runner: AgentRunnerAdapter, + streaming_service: StreamingServiceAdapter, + push_notification_service: Any = None, + ): + """ + Initialize the task manager. + + Args: + redis_cache: Cache service for storing task data + agent_runner: Adapter for agent execution + streaming_service: Adapter for event streaming + push_notification_service: Service for sending push notifications + """ + self.cache = redis_cache + self.agent_runner = agent_runner + self.streaming_service = streaming_service + self.push_notification_service = push_notification_service + self._running_tasks = {} + + async def on_send_task( + self, + task_id: str, + agent_id: str, + message: Dict[str, Any], + session_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + input_mode: str = "text", + output_modes: List[str] = ["text"], + db: Optional[Session] = None, + ) -> Dict[str, Any]: + """ + Process a request to send a task. + + Args: + task_id: Task ID + agent_id: Agent ID + message: User message + session_id: Session ID (optional) + metadata: Additional metadata (optional) + input_mode: Input mode (text, JSON, etc.) + output_modes: Supported output modes + db: Database session + + Returns: + Response with task result + """ + if not session_id: + session_id = f"{task_id}_{agent_id}" + + if not metadata: + metadata = {} + + # Update status to "submitted" + task_data = { + "id": task_id, + "sessionId": session_id, + "status": { + "state": "submitted", + "timestamp": datetime.now().isoformat(), + "message": None, + "error": None, + }, + "artifacts": [], + "history": [], + "metadata": metadata, + } + + # Store initial task data + await self.cache.set(f"task:{task_id}", task_data) + + # Check for push notification configurations + push_config = await self.cache.get(f"task:{task_id}:push") + if push_config and self.push_notification_service: + # Send initial notification + await self.push_notification_service.send_notification( + url=push_config["url"], + task_id=task_id, + state="submitted", + headers=push_config.get("headers", {}), + ) + + try: + # Update status to "running" + task_data["status"].update( + {"state": "running", "timestamp": datetime.now().isoformat()} + ) + await self.cache.set(f"task:{task_id}", task_data) + + # Notify "running" state + if push_config and self.push_notification_service: + await self.push_notification_service.send_notification( + url=push_config["url"], + task_id=task_id, + state="running", + headers=push_config.get("headers", {}), + ) + + # Extract user message + user_message = None + try: + user_message = message["parts"][0]["text"] + except (KeyError, IndexError): + user_message = "" + + # Execute the agent + response = await self.agent_runner.run_agent( + agent_id=agent_id, + task_id=task_id, + message=user_message, + session_id=session_id, + db=db, + ) + + # Check if the response is a dictionary (error) or a string (success) + if isinstance(response, dict) and response.get("status") == "error": + # Error response + final_response = f"Error: {response.get('error', 'Unknown error')}" + + # Update status to "failed" + task_data["status"].update( + { + "state": "failed", + "timestamp": datetime.now().isoformat(), + "error": { + "code": "AGENT_EXECUTION_ERROR", + "message": response.get("error", "Unknown error"), + }, + "message": { + "role": "system", + "parts": [{"type": "text", "text": final_response}], + }, + } + ) + + # Notify "failed" state + if push_config and self.push_notification_service: + await self.push_notification_service.send_notification( + url=push_config["url"], + task_id=task_id, + state="failed", + message={ + "role": "system", + "parts": [{"type": "text", "text": final_response}], + }, + headers=push_config.get("headers", {}), + ) + else: + # Success response + final_response = ( + response.get("content") if isinstance(response, dict) else response + ) + + # Update status to "completed" + task_data["status"].update( + { + "state": "completed", + "timestamp": datetime.now().isoformat(), + "message": { + "role": "agent", + "parts": [{"type": "text", "text": final_response}], + }, + } + ) + + # Add artifacts + if final_response: + task_data["artifacts"].append( + { + "type": "text", + "content": final_response, + "metadata": { + "generated_at": datetime.now().isoformat(), + "content_type": "text/plain", + }, + } + ) + + # Add history of messages + history_length = metadata.get("historyLength", 50) + try: + history_messages = get_session_events( + self.agent_runner.session_service, session_id + ) + history_messages = history_messages[-history_length:] + + formatted_history = [] + for event in history_messages: + if event.content and event.content.parts: + role = ( + "agent" + if event.content.role == "model" + else event.content.role + ) + formatted_history.append( + { + "role": role, + "parts": [ + {"type": "text", "text": part.text} + for part in event.content.parts + if part.text + ], + } + ) + + task_data["history"] = formatted_history + except Exception as e: + logger.error(f"Error processing history: {str(e)}") + + # Notify "completed" state + if push_config and self.push_notification_service: + await self.push_notification_service.send_notification( + url=push_config["url"], + task_id=task_id, + state="completed", + message={ + "role": "agent", + "parts": [{"type": "text", "text": final_response}], + }, + headers=push_config.get("headers", {}), + ) + + except Exception as e: + logger.error(f"Error executing task {task_id}: {str(e)}") + + # Update status to "failed" + task_data["status"].update( + { + "state": "failed", + "timestamp": datetime.now().isoformat(), + "error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)}, + } + ) + + # Notify "failed" state + if push_config and self.push_notification_service: + await self.push_notification_service.send_notification( + url=push_config["url"], + task_id=task_id, + state="failed", + message={ + "role": "system", + "parts": [{"type": "text", "text": str(e)}], + }, + headers=push_config.get("headers", {}), + ) + + # Store final result + await self.cache.set(f"task:{task_id}", task_data) + return task_data + + async def on_send_task_subscribe( + self, + task_id: str, + agent_id: str, + message: Dict[str, Any], + session_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + input_mode: str = "text", + output_modes: List[str] = ["text"], + db: Optional[Session] = None, + ) -> AsyncGenerator[str, None]: + """ + Process a request to send a task with streaming. + + Args: + task_id: Task ID + agent_id: Agent ID + message: User message + session_id: Session ID (optional) + metadata: Additional metadata (optional) + input_mode: Input mode (text, JSON, etc.) + output_modes: Supported output modes + db: Database session + + Yields: + Streaming events in SSE (Server-Sent Events) format + """ + if not session_id: + session_id = f"{task_id}_{agent_id}" + + if not metadata: + metadata = {} + + # Extract user message + user_message = "" + try: + user_message = message["parts"][0]["text"] + except (KeyError, IndexError): + pass + + # Generate streaming events + async for event in self.streaming_service.stream_response( + agent_id=agent_id, + task_id=task_id, + message=user_message, + session_id=session_id, + db=db, + ): + yield event + + async def on_get_task(self, task_id: str) -> Dict[str, Any]: + """ + Query the status of a task by ID. + + Args: + task_id: Task ID + + Returns: + Current task status + + Raises: + Exception: If the task is not found + """ + task_data = await self.cache.get(f"task:{task_id}") + if not task_data: + raise Exception(f"Task {task_id} not found") + return task_data + + async def on_cancel_task(self, task_id: str) -> Dict[str, Any]: + """ + Cancel a running task. + + Args: + task_id: Task ID to be cancelled + + Returns: + Task status after cancellation + + Raises: + Exception: If the task is not found or cannot be cancelled + """ + task_data = await self.cache.get(f"task:{task_id}") + if not task_data: + raise Exception(f"Task {task_id} not found") + + # Check if the task is in a state that can be cancelled + current_state = task_data["status"]["state"] + if current_state not in ["submitted", "running"]: + raise Exception(f"Cannot cancel task in {current_state} state") + + # Cancel the task in the runner if it is running + running_task = self._running_tasks.get(task_id) + if running_task: + # Try to cancel the running task + if hasattr(running_task, "cancel"): + running_task.cancel() + + # Update status to "cancelled" + task_data["status"].update( + { + "state": "cancelled", + "timestamp": datetime.now().isoformat(), + "message": { + "role": "system", + "parts": [{"type": "text", "text": "Task cancelled by user"}], + }, + } + ) + + # Update cache + await self.cache.set(f"task:{task_id}", task_data) + + # Send push notification if configured + push_config = await self.cache.get(f"task:{task_id}:push") + if push_config and self.push_notification_service: + await self.push_notification_service.send_notification( + url=push_config["url"], + task_id=task_id, + state="cancelled", + message={ + "role": "system", + "parts": [{"type": "text", "text": "Task cancelled by user"}], + }, + headers=push_config.get("headers", {}), + ) + + return task_data + + async def on_set_task_push_notification( + self, task_id: str, notification_config: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Configure push notifications for a task. + + Args: + task_id: Task ID + notification_config: Notification configuration (URL and headers) + + Returns: + Updated configuration + """ + # Validate configuration + url = notification_config.get("url") + if not url: + raise ValueError("Push notification URL is required") + + headers = notification_config.get("headers", {}) + + # Store configuration + config = {"url": url, "headers": headers} + await self.cache.set(f"task:{task_id}:push", config) + + return config + + async def on_get_task_push_notification(self, task_id: str) -> Dict[str, Any]: + """ + Get the push notification configuration for a task. + + Args: + task_id: Task ID + + Returns: + Push notification configuration + + Raises: + Exception: If there is no configuration for the task + """ + config = await self.cache.get(f"task:{task_id}:push") + if not config: + raise Exception(f"No push notification configuration for task {task_id}") + return config + + +class A2AServer: + """ + A2A server compatible with JSON-RPC 2.0. + + This class processes JSON-RPC requests and forwards them to + the appropriate handlers in the A2ATaskManager. + """ + + def __init__(self, task_manager: A2ATaskManager, agent_card=None): + """ + Initialize the A2A server. + + Args: + task_manager: Task manager + agent_card: Agent card information + """ + self.task_manager = task_manager + self.agent_card = agent_card + + async def process_request( + self, + request: Request, + agent_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> Union[Response, JSONResponse, StreamingResponse]: + """ + Process a JSON-RPC request. + + Args: + request: HTTP request + agent_id: Optional agent ID to inject into the request + db: Database session + + Returns: + Appropriate response (JSON or Streaming) + """ + try: + # Try to parse the JSON payload + try: + logger.info("Starting JSON-RPC request processing") + body = await request.json() + logger.info(f"Received JSON data: {json.dumps(body)}") + method = body.get("method", "unknown") + logger.info(f"[SERVER] Processando método: {method}") + + # Validate the request using the A2A validator + json_rpc_request = A2ARequest.validate_python(body) + logger.info( + f"[SERVER] Request validado como: {type(json_rpc_request).__name__}" + ) + + original_db = self.task_manager.db + try: + # Set the db temporarily + if db is not None: + self.task_manager.db = db + + # Process the request + if isinstance(json_rpc_request, SendTaskRequest): + logger.info( + f"[SERVER] Processando SendTaskRequest para task_id={json_rpc_request.params.id}" + ) + json_rpc_request.params.agentId = agent_id + result = await self.task_manager.on_send_task(json_rpc_request) + elif isinstance(json_rpc_request, SendTaskStreamingRequest): + logger.info( + f"[SERVER] Processando SendTaskStreamingRequest para task_id={json_rpc_request.params.id}" + ) + json_rpc_request.params.agentId = agent_id + result = await self.task_manager.on_send_task_subscribe( + json_rpc_request + ) + elif isinstance(json_rpc_request, GetTaskRequest): + logger.info( + f"[SERVER] Processando GetTaskRequest para task_id={json_rpc_request.params.id}" + ) + result = await self.task_manager.on_get_task(json_rpc_request) + elif isinstance(json_rpc_request, CancelTaskRequest): + logger.info( + f"[SERVER] Processando CancelTaskRequest para task_id={json_rpc_request.params.id}" + ) + result = await self.task_manager.on_cancel_task( + json_rpc_request + ) + elif isinstance(json_rpc_request, SetTaskPushNotificationRequest): + logger.info( + f"[SERVER] Processando SetTaskPushNotificationRequest para task_id={json_rpc_request.params.id}" + ) + result = await self.task_manager.on_set_task_push_notification( + json_rpc_request + ) + elif isinstance(json_rpc_request, GetTaskPushNotificationRequest): + logger.info( + f"[SERVER] Processando GetTaskPushNotificationRequest para task_id={json_rpc_request.params.id}" + ) + result = await self.task_manager.on_get_task_push_notification( + json_rpc_request + ) + elif isinstance(json_rpc_request, TaskResubscriptionRequest): + logger.info( + f"[SERVER] Processando TaskResubscriptionRequest para task_id={json_rpc_request.params.id}" + ) + result = await self.task_manager.on_resubscribe_to_task( + json_rpc_request + ) + else: + logger.warning( + f"[SERVER] Tipo de request não suportado: {type(json_rpc_request)}" + ) + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": body.get("id"), + "error": { + "code": -32601, + "message": "Method not found", + "data": {"detail": f"Method not supported"}, + }, + }, + ) + finally: + # Restore the original db + self.task_manager.db = original_db + + # Create appropriate response + return self._create_response(result) + + except json.JSONDecodeError as e: + # Error parsing JSON + logger.error(f"Error parsing JSON request: {str(e)}") + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32700, + "message": "Parse error", + "data": {"detail": str(e)}, + }, + }, + ) + except Exception as e: + # Other validation errors + logger.error(f"Error validating request: {str(e)}") + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "id": body.get("id") if "body" in locals() else None, + "error": { + "code": -32600, + "message": "Invalid Request", + "data": {"detail": str(e)}, + }, + }, + ) + + except Exception as e: + logger.error(f"Error processing JSON-RPC request: {str(e)}", exc_info=True) + return JSONResponse( + status_code=500, + content={ + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32603, + "message": "Internal error", + "data": {"detail": str(e)}, + }, + }, + ) + + def _create_response(self, result: Any) -> Union[JSONResponse, StreamingResponse]: + """ + Create appropriate response based on result type. + + Args: + result: Result from task manager + + Returns: + JSON or streaming response + """ + if isinstance(result, AsyncIterable): + # Result in streaming (SSE) + async def event_generator(): + async for item in result: + if hasattr(item, "model_dump_json"): + yield {"data": item.model_dump_json(exclude_none=True)} + else: + yield {"data": json.dumps(item)} + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + elif hasattr(result, "model_dump"): + # Result is a Pydantic object + return JSONResponse(result.model_dump(exclude_none=True)) + + else: + # Result is a dictionary or other simple type + return JSONResponse(result) + + async def get_agent_card( + self, request: Request, db: Optional[Session] = None + ) -> JSONResponse: + """ + Get the agent card. + + Args: + request: HTTP request + db: Database session + + Returns: + Agent card as JSON + """ + if not self.agent_card: + logger.error("Agent card not configured") + return JSONResponse( + status_code=404, content={"error": "Agent card not configured"} + ) + + # If there is db, set it temporarily in the task_manager + if db and hasattr(self.task_manager, "db"): + original_db = self.task_manager.db + try: + self.task_manager.db = db + + # If it's a Pydantic object, convert to dictionary + if hasattr(self.agent_card, "model_dump"): + return JSONResponse(self.agent_card.model_dump(exclude_none=True)) + else: + return JSONResponse(self.agent_card) + finally: + # Restore the original db + self.task_manager.db = original_db + else: + # If it's a Pydantic object, convert to dictionary + if hasattr(self.agent_card, "model_dump"): + return JSONResponse(self.agent_card.model_dump(exclude_none=True)) + else: + return JSONResponse(self.agent_card) diff --git a/src/services/a2a_task_manager_service.py b/src/services/a2a_task_manager_service.py new file mode 100644 index 00000000..94f490c5 --- /dev/null +++ b/src/services/a2a_task_manager_service.py @@ -0,0 +1,888 @@ +""" +A2A Task Manager Service. + +This service implements task management for the A2A protocol, handling task lifecycle +including execution, streaming, push notifications, status queries, and cancellation. +""" + +import asyncio +import json +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Union, AsyncIterable + +from sqlalchemy.orm import Session + +from src.schemas.a2a.exceptions import ( + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + InternalError, + ContentTypeNotSupportedError, +) + +from src.schemas.a2a.types import ( + JSONRPCResponse, + TaskIdParams, + TaskQueryParams, + GetTaskRequest, + SendTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationRequest, + GetTaskPushNotificationRequest, + GetTaskResponse, + CancelTaskResponse, + SendTaskResponse, + SetTaskPushNotificationResponse, + GetTaskPushNotificationResponse, + TaskSendParams, + TaskStatus, + TaskState, + TaskResubscriptionRequest, + SendTaskStreamingRequest, + SendTaskStreamingResponse, + Artifact, + PushNotificationConfig, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + JSONRPCError, + TaskPushNotificationConfig, + Message, + TextPart, + Task, +) +from src.services.redis_cache_service import RedisCacheService +from src.utils.a2a_utils import ( + are_modalities_compatible, + new_incompatible_types_error, + new_not_implemented_error, + create_task_id, + format_error_response, +) + +logger = logging.getLogger(__name__) + + +class A2ATaskManager: + """ + A2A Task Manager implementation. + + This class manages the lifecycle of A2A tasks, including: + - Task submission and execution + - Task status queries + - Task cancellation + - Push notification configuration + - SSE streaming for real-time updates + """ + + def __init__( + self, + redis_cache: RedisCacheService, + agent_runner=None, + streaming_service=None, + push_notification_service=None, + db=None, + ): + """ + Initialize the A2A Task Manager. + + Args: + redis_cache: Redis cache service for task storage + agent_runner: Agent runner service for task execution + streaming_service: Streaming service for SSE + push_notification_service: Service for push notifications + db: Database session + """ + self.redis_cache = redis_cache + self.agent_runner = agent_runner + self.streaming_service = streaming_service + self.push_notification_service = push_notification_service + self.db = db + self.lock = asyncio.Lock() + self.subscriber_lock = asyncio.Lock() + self.task_sse_subscribers = {} + + async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: + """ + Manipula requisição para obter informações sobre uma tarefa. + + Args: + request: Requisição Get Task do A2A + + Returns: + Resposta com os detalhes da tarefa + """ + try: + task_id = request.params.id + history_length = request.params.historyLength + + # Busca dados da tarefa do cache + task_data = await self.redis_cache.get(f"task:{task_id}") + + if not task_data: + logger.warning(f"Tarefa não encontrada: {task_id}") + return GetTaskResponse(id=request.id, error=TaskNotFoundError()) + + # Cria uma instância Task a partir dos dados do cache + task = Task.model_validate(task_data) + + # Se o parâmetro historyLength estiver presente, manipula o histórico + if history_length is not None and task.history: + if history_length == 0: + task.history = [] + elif len(task.history) > history_length: + task.history = task.history[-history_length:] + + return GetTaskResponse(id=request.id, result=task) + except Exception as e: + logger.error(f"Erro ao processar on_get_task: {str(e)}") + return GetTaskResponse(id=request.id, error=InternalError(message=str(e))) + + async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: + """ + Handle request to cancel a running task. + + Args: + request: The JSON-RPC request to cancel a task + + Returns: + Response with updated task data or error + """ + logger.info(f"Cancelling task {request.params.id}") + task_id_params = request.params + + try: + task_data = await self.redis_cache.get(f"task:{task_id_params.id}") + if not task_data: + logger.warning(f"Task {task_id_params.id} not found for cancellation") + return CancelTaskResponse(id=request.id, error=TaskNotFoundError()) + + # Check if task can be cancelled + current_state = task_data.get("status", {}).get("state") + if current_state not in [TaskState.SUBMITTED, TaskState.WORKING]: + logger.warning( + f"Task {task_id_params.id} in state {current_state} cannot be cancelled" + ) + return CancelTaskResponse(id=request.id, error=TaskNotCancelableError()) + + # Update task status to cancelled + task_data["status"] = { + "state": TaskState.CANCELED, + "timestamp": datetime.now().isoformat(), + "message": { + "role": "agent", + "parts": [{"type": "text", "text": "Task cancelled by user"}], + }, + } + + # Save updated task data + await self.redis_cache.set(f"task:{task_id_params.id}", task_data) + + # Send push notification if configured + await self._send_push_notification_for_task( + task_id_params.id, "canceled", system_message="Task cancelled by user" + ) + + # Publish event to SSE subscribers + await self._publish_task_update( + task_id_params.id, + TaskStatusUpdateEvent( + id=task_id_params.id, + status=TaskStatus( + state=TaskState.CANCELED, + timestamp=datetime.now(), + message=Message( + role="agent", + parts=[TextPart(text="Task cancelled by user")], + ), + ), + final=True, + ), + ) + + return CancelTaskResponse(id=request.id, result=task_data) + + except Exception as e: + logger.error(f"Error cancelling task: {str(e)}", exc_info=True) + return CancelTaskResponse( + id=request.id, + error=InternalError(message=f"Error cancelling task: {str(e)}"), + ) + + async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: + """ + Manipula requisição para enviar uma nova tarefa. + + Args: + request: Requisição de envio de tarefa + + Returns: + Resposta com os detalhes da tarefa criada + """ + try: + params = request.params + task_id = params.id + logger.info(f"Recebendo tarefa {task_id}") + + # Verifica se já existe uma tarefa com esse ID + existing_task = await self.redis_cache.get(f"task:{task_id}") + if existing_task: + # Se a tarefa já existe e está em progresso, retorna a tarefa atual + if existing_task.get("status", {}).get("state") in [ + TaskState.WORKING, + TaskState.COMPLETED, + ]: + logger.info( + f"Tarefa {task_id} já existe e está em progresso/concluída" + ) + return SendTaskResponse( + id=request.id, result=Task.model_validate(existing_task) + ) + + # Se a tarefa existe mas falhou ou foi cancelada, podemos reprocessá-la + logger.info(f"Reprocessando tarefa existente {task_id}") + + # Verifica compatibilidade de modalidades + server_output_modes = [] + if self.agent_runner: + # Tenta obter modos suportados do agente + try: + server_output_modes = await self.agent_runner.get_supported_modes() + except Exception as e: + logger.warning(f"Erro ao obter modos suportados: {str(e)}") + server_output_modes = ["text"] # Fallback para texto + + if not are_modalities_compatible( + server_output_modes, params.acceptedOutputModes + ): + logger.warning( + f"Modos incompatíveis: servidor={server_output_modes}, cliente={params.acceptedOutputModes}" + ) + return SendTaskResponse( + id=request.id, error=ContentTypeNotSupportedError() + ) + + # Cria dados da tarefa + task_data = await self._create_task_data(params) + + # Armazena a tarefa no cache + await self.redis_cache.set(f"task:{task_id}", task_data) + + # Configura notificações push, se fornecidas + if params.pushNotification: + await self.redis_cache.set( + f"task_notification:{task_id}", params.pushNotification.model_dump() + ) + + # Inicia a execução da tarefa em background + asyncio.create_task(self._execute_task(task_data, params)) + + # Converte para objeto Task e retorna + task = Task.model_validate(task_data) + return SendTaskResponse(id=request.id, result=task) + + except Exception as e: + logger.error(f"Erro ao processar on_send_task: {str(e)}") + return SendTaskResponse(id=request.id, error=InternalError(message=str(e))) + + async def on_send_task_subscribe( + self, request: SendTaskStreamingRequest + ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: + """ + Handle request to send a task and subscribe to streaming updates. + + Args: + request: The JSON-RPC request to send a task with streaming + + Returns: + Stream of events or error response + """ + logger.info(f"Sending task with streaming {request.params.id}") + task_send_params = request.params + + try: + # Check output mode compatibility + if not are_modalities_compatible( + ["text", "application/json"], # Default supported modes + task_send_params.acceptedOutputModes, + ): + return new_incompatible_types_error(request.id) + + # Create initial task data + task_data = await self._create_task_data(task_send_params) + + # Setup SSE consumer + sse_queue = await self._setup_sse_consumer(task_send_params.id) + + # Execute task asynchronously (fire and forget) + asyncio.create_task(self._execute_task(task_data, task_send_params)) + + # Return generator to dequeue events for SSE + return self._dequeue_events_for_sse( + request.id, task_send_params.id, sse_queue + ) + + except Exception as e: + logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True) + return SendTaskStreamingResponse( + id=request.id, + error=InternalError( + message=f"Error setting up streaming task: {str(e)}" + ), + ) + + async def on_set_task_push_notification( + self, request: SetTaskPushNotificationRequest + ) -> SetTaskPushNotificationResponse: + """ + Configure push notifications for a task. + + Args: + request: The JSON-RPC request to set push notification + + Returns: + Response with configuration or error + """ + logger.info(f"Setting push notification for task {request.params.id}") + task_notification_params = request.params + + try: + if not self.push_notification_service: + logger.warning("Push notifications not supported") + return SetTaskPushNotificationResponse( + id=request.id, error=PushNotificationNotSupportedError() + ) + + # Check if task exists + task_data = await self.redis_cache.get( + f"task:{task_notification_params.id}" + ) + if not task_data: + logger.warning( + f"Task {task_notification_params.id} not found for setting push notification" + ) + return SetTaskPushNotificationResponse( + id=request.id, error=TaskNotFoundError() + ) + + # Save push notification config + config = { + "url": task_notification_params.pushNotificationConfig.url, + "headers": {}, # Add auth headers if needed + } + + await self.redis_cache.set( + f"task:{task_notification_params.id}:push", config + ) + + return SetTaskPushNotificationResponse( + id=request.id, result=task_notification_params + ) + + except Exception as e: + logger.error(f"Error setting push notification: {str(e)}", exc_info=True) + return SetTaskPushNotificationResponse( + id=request.id, + error=InternalError( + message=f"Error setting push notification: {str(e)}" + ), + ) + + async def on_get_task_push_notification( + self, request: GetTaskPushNotificationRequest + ) -> GetTaskPushNotificationResponse: + """ + Get push notification configuration for a task. + + Args: + request: The JSON-RPC request to get push notification config + + Returns: + Response with configuration or error + """ + logger.info(f"Getting push notification for task {request.params.id}") + task_params = request.params + + try: + if not self.push_notification_service: + logger.warning("Push notifications not supported") + return GetTaskPushNotificationResponse( + id=request.id, error=PushNotificationNotSupportedError() + ) + + # Check if task exists + task_data = await self.redis_cache.get(f"task:{task_params.id}") + if not task_data: + logger.warning( + f"Task {task_params.id} not found for getting push notification" + ) + return GetTaskPushNotificationResponse( + id=request.id, error=TaskNotFoundError() + ) + + # Get push notification config + config = await self.redis_cache.get(f"task:{task_params.id}:push") + if not config: + logger.warning(f"No push notification config for task {task_params.id}") + return GetTaskPushNotificationResponse( + id=request.id, + error=InternalError( + message="No push notification configuration found" + ), + ) + + result = TaskPushNotificationConfig( + id=task_params.id, + pushNotificationConfig=PushNotificationConfig( + url=config.get("url"), token=None, authentication=None + ), + ) + + return GetTaskPushNotificationResponse(id=request.id, result=result) + + except Exception as e: + logger.error(f"Error getting push notification: {str(e)}", exc_info=True) + return GetTaskPushNotificationResponse( + id=request.id, + error=InternalError( + message=f"Error getting push notification: {str(e)}" + ), + ) + + async def on_resubscribe_to_task( + self, request: TaskResubscriptionRequest + ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: + """ + Resubscribe to a task's streaming events. + + Args: + request: The JSON-RPC request to resubscribe + + Returns: + Stream of events or error response + """ + logger.info(f"Resubscribing to task {request.params.id}") + task_params = request.params + + try: + # Check if task exists + task_data = await self.redis_cache.get(f"task:{task_params.id}") + if not task_data: + logger.warning(f"Task {task_params.id} not found for resubscription") + return JSONRPCResponse(id=request.id, error=TaskNotFoundError()) + + # Setup SSE consumer with resubscribe flag + try: + sse_queue = await self._setup_sse_consumer( + task_params.id, is_resubscribe=True + ) + except ValueError: + logger.warning( + f"Task {task_params.id} not available for resubscription" + ) + return JSONRPCResponse( + id=request.id, + error=InternalError( + message="Task not available for resubscription" + ), + ) + + # Send initial status update to the new subscriber + status = task_data.get("status", {}) + final = status.get("state") in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + ] + + # Convert to TaskStatus object + task_status = TaskStatus( + state=status.get("state", TaskState.UNKNOWN), + timestamp=datetime.fromisoformat( + status.get("timestamp", datetime.now().isoformat()) + ), + message=status.get("message"), + ) + + # Publish to the specific queue + await sse_queue.put( + TaskStatusUpdateEvent( + id=task_params.id, status=task_status, final=final + ) + ) + + # Return generator to dequeue events for SSE + return self._dequeue_events_for_sse(request.id, task_params.id, sse_queue) + + except Exception as e: + logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True) + return JSONRPCResponse( + id=request.id, + error=InternalError(message=f"Error resubscribing to task: {str(e)}"), + ) + + async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]: + """ + Create initial task data structure. + + Args: + params: Task send parameters + + Returns: + Task data dictionary + """ + # Create task with initial status + task_data = { + "id": params.id, + "sessionId": params.sessionId, + "status": { + "state": TaskState.SUBMITTED, + "timestamp": datetime.now().isoformat(), + "message": None, + "error": None, + }, + "artifacts": [], + "history": [params.message.model_dump()], + "metadata": params.metadata or {}, + } + + # Save task data + await self.redis_cache.set(f"task:{params.id}", task_data) + + return task_data + + async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None: + """ + Executa uma tarefa usando o adaptador do agente. + + Esta função é responsável pela execução real da tarefa pelo agente, + atualizando seu status conforme o progresso. + + Args: + task: Dados da tarefa a ser executada + params: Parâmetros de envio da tarefa + """ + task_id = task["id"] + agent_id = params.agentId + message_text = "" + + # Extrai o texto da mensagem + if params.message and params.message.parts: + for part in params.message.parts: + if part.type == "text": + message_text += part.text + + if not message_text: + await self._update_task_status( + task_id, TaskState.FAILED, "Mensagem não contém texto", final=True + ) + return + + # Verificamos se é uma execução em andamento + task_status = task.get("status", {}) + if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]: + logger.info(f"Tarefa {task_id} já está em execução ou concluída") + return + + try: + # Atualiza para estado "working" + await self._update_task_status( + task_id, TaskState.WORKING, "Processando solicitação" + ) + + # Executa o agente + if self.agent_runner: + response = await self.agent_runner.run_agent( + agent_id=agent_id, + message=message_text, + session_id=params.sessionId, + task_id=task_id, + ) + + # Processa a resposta do agente + if response and isinstance(response, dict): + # Extrai texto da resposta + response_text = response.get("text", "") + if not response_text and "message" in response: + message = response.get("message", {}) + parts = message.get("parts", []) + for part in parts: + if part.get("type") == "text": + response_text += part.get("text", "") + + # Constrói a mensagem final do agente + if response_text: + # Cria um artefato para a resposta + artifact = Artifact( + name="response", + parts=[TextPart(text=response_text)], + index=0, + lastChunk=True, + ) + + # Adiciona o artefato à tarefa + await self._add_task_artifact(task_id, artifact) + + # Atualiza o status da tarefa para completado + await self._update_task_status( + task_id, TaskState.COMPLETED, response_text, final=True + ) + else: + await self._update_task_status( + task_id, + TaskState.FAILED, + "O agente não retornou uma resposta válida", + final=True, + ) + else: + await self._update_task_status( + task_id, + TaskState.FAILED, + "Resposta inválida do agente", + final=True, + ) + else: + await self._update_task_status( + task_id, + TaskState.FAILED, + "Adaptador do agente não configurado", + final=True, + ) + except Exception as e: + logger.error(f"Erro na execução da tarefa {task_id}: {str(e)}") + await self._update_task_status( + task_id, TaskState.FAILED, f"Erro ao processar: {str(e)}", final=True + ) + + async def _update_task_status( + self, task_id: str, state: TaskState, message_text: str, final: bool = False + ) -> None: + """ + Atualiza o status de uma tarefa. + + Args: + task_id: ID da tarefa a ser atualizada + state: Novo estado da tarefa + message_text: Texto da mensagem associada ao status + final: Indica se este é o status final da tarefa + """ + try: + # Busca dados atuais da tarefa + task_data = await self.redis_cache.get(f"task:{task_id}") + if not task_data: + logger.warning( + f"Não foi possível atualizar status: tarefa {task_id} não encontrada" + ) + return + + # Cria objeto de status com a mensagem + agent_message = Message( + role="agent", + parts=[TextPart(text=message_text)], + metadata={"timestamp": datetime.now().isoformat()}, + ) + + status = TaskStatus( + state=state, message=agent_message, timestamp=datetime.now() + ) + + # Atualiza o status na tarefa + task_data["status"] = status.model_dump(exclude_none=True) + + # Atualiza o histórico, se existir + if "history" not in task_data: + task_data["history"] = [] + + # Adiciona a mensagem ao histórico + task_data["history"].append(agent_message.model_dump(exclude_none=True)) + + # Armazena a tarefa atualizada + await self.redis_cache.set(f"task:{task_id}", task_data) + + # Cria evento de atualização de status + status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final) + + # Publica atualização + await self._publish_task_update(task_id, status_event) + + # Envia notificação push, se configurada + if final or state in [ + TaskState.FAILED, + TaskState.COMPLETED, + TaskState.CANCELED, + ]: + await self._send_push_notification_for_task( + task_id=task_id, state=state, message_text=message_text + ) + except Exception as e: + logger.error(f"Erro ao atualizar status da tarefa {task_id}: {str(e)}") + + async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None: + """ + Add an artifact to a task and publish the update. + + Args: + task_id: Task ID + artifact: Artifact to add + """ + logger.info(f"Adding artifact to task {task_id}") + + # Update task data + task_data = await self.redis_cache.get(f"task:{task_id}") + if task_data: + if "artifacts" not in task_data: + task_data["artifacts"] = [] + + # Convert artifact to dict + artifact_dict = artifact.model_dump() + task_data["artifacts"].append(artifact_dict) + await self.redis_cache.set(f"task:{task_id}", task_data) + + # Create artifact update event + event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact) + + # Publish event + await self._publish_task_update(task_id, event) + + async def _publish_task_update( + self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent] + ) -> None: + """ + Publish a task update event to all subscribers. + + Args: + task_id: Task ID + event: Event to publish + """ + async with self.subscriber_lock: + if task_id not in self.task_sse_subscribers: + return + + subscribers = self.task_sse_subscribers[task_id] + for subscriber in subscribers: + try: + await subscriber.put(event) + except Exception as e: + logger.error(f"Error publishing event to subscriber: {str(e)}") + + async def _send_push_notification_for_task( + self, + task_id: str, + state: str, + message_text: str = None, + system_message: str = None, + ) -> None: + """ + Send push notification for a task if configured. + + Args: + task_id: Task ID + state: Task state + message_text: Optional message text + system_message: Optional system message + """ + if not self.push_notification_service: + return + + try: + # Get push notification config + config = await self.redis_cache.get(f"task:{task_id}:push") + if not config: + return + + # Create message if provided + message = None + if message_text: + message = { + "role": "agent", + "parts": [{"type": "text", "text": message_text}], + } + elif system_message: + # We use 'agent' instead of 'system' since Message only accepts 'user' or 'agent' + message = { + "role": "agent", + "parts": [{"type": "text", "text": system_message}], + } + + # Send notification + await self.push_notification_service.send_notification( + url=config["url"], + task_id=task_id, + state=state, + message=message, + headers=config.get("headers", {}), + ) + + except Exception as e: + logger.error( + f"Error sending push notification for task {task_id}: {str(e)}" + ) + + async def _setup_sse_consumer( + self, task_id: str, is_resubscribe: bool = False + ) -> asyncio.Queue: + """ + Set up an SSE consumer queue for a task. + + Args: + task_id: Task ID + is_resubscribe: Whether this is a resubscription + + Returns: + Queue for events + + Raises: + ValueError: If resubscribing to non-existent task + """ + async with self.subscriber_lock: + if task_id not in self.task_sse_subscribers: + if is_resubscribe: + raise ValueError("Task not found for resubscription") + self.task_sse_subscribers[task_id] = [] + + queue = asyncio.Queue() + self.task_sse_subscribers[task_id].append(queue) + return queue + + async def _dequeue_events_for_sse( + self, request_id: str, task_id: str, event_queue: asyncio.Queue + ) -> AsyncIterable[SendTaskStreamingResponse]: + """ + Dequeue and yield events for SSE streaming. + + Args: + request_id: Request ID + task_id: Task ID + event_queue: Queue for events + + Yields: + SSE events wrapped in SendTaskStreamingResponse + """ + try: + while True: + event = await event_queue.get() + + if isinstance(event, JSONRPCError): + yield SendTaskStreamingResponse(id=request_id, error=event) + break + + yield SendTaskStreamingResponse(id=request_id, result=event) + + # Check if this is the final event + is_final = False + if hasattr(event, "final") and event.final: + is_final = True + + if is_final: + break + finally: + # Clean up the subscription when done + async with self.subscriber_lock: + if task_id in self.task_sse_subscribers: + try: + self.task_sse_subscribers[task_id].remove(event_queue) + # Remove the task from the dict if no more subscribers + if not self.task_sse_subscribers[task_id]: + del self.task_sse_subscribers[task_id] + except ValueError: + pass # Queue might have been removed already diff --git a/src/services/agent_service.py b/src/services/agent_service.py index 279ac899..c37f4981 100644 --- a/src/services/agent_service.py +++ b/src/services/agent_service.py @@ -3,7 +3,7 @@ from sqlalchemy.exc import SQLAlchemyError from fastapi import HTTPException, status from src.models.models import Agent from src.schemas.schemas import AgentCreate -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Union from src.services.mcp_server_service import get_mcp_server import uuid import logging @@ -20,9 +20,17 @@ def validate_sub_agents(db: Session, sub_agents: List[uuid.UUID]) -> bool: return True -def get_agent(db: Session, agent_id: uuid.UUID) -> Optional[Agent]: +def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]: """Search for an agent by ID""" try: + # Convert to UUID if it's a string + if isinstance(agent_id, str): + try: + agent_id = uuid.UUID(agent_id) + except ValueError: + logger.warning(f"Invalid agent ID: {agent_id}") + return None + agent = db.query(Agent).filter(Agent.id == agent_id).first() if not agent: logger.warning(f"Agent not found: {agent_id}") diff --git a/src/services/push_notification_auth_service.py b/src/services/push_notification_auth_service.py new file mode 100644 index 00000000..ca4174ec --- /dev/null +++ b/src/services/push_notification_auth_service.py @@ -0,0 +1,265 @@ +""" +Push Notification Authentication Service. + +This service implements JWT authentication for A2A push notifications, +allowing agents to send authenticated notifications and clients to verify +the authenticity of received notifications. +""" + +from jwcrypto import jwk +import uuid +import time +import json +import hashlib +import httpx +import logging +import jwt +from jwt import PyJWK, PyJWKClient +from fastapi import Request +from starlette.responses import JSONResponse +from typing import Dict, Any, Optional + +logger = logging.getLogger(__name__) +AUTH_HEADER_PREFIX = "Bearer " + + +class PushNotificationAuth: + """ + Base class for push notification authentication. + + Contains common methods used by both the sender and the receiver + of push notifications. + """ + + def _calculate_request_body_sha256(self, data: Dict[str, Any]) -> str: + """ + Calculates the SHA256 hash of the request body. + + This logic needs to be the same for the agent that signs the payload + and for the client that verifies it. + + Args: + data: Request body data + + Returns: + SHA256 hash as a hexadecimal string + """ + body_str = json.dumps( + data, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ) + return hashlib.sha256(body_str.encode()).hexdigest() + + +class PushNotificationSenderAuth(PushNotificationAuth): + """ + Authentication for the push notification sender. + + This class is used by the A2A server to authenticate notifications + sent to callback URLs of clients. + """ + + def __init__(self): + """ + Initialize the push notification sender authentication service. + """ + self.public_keys = [] + self.private_key_jwk = None + + @staticmethod + async def verify_push_notification_url(url: str) -> bool: + """ + Verifies if a push notification URL is valid and responds correctly. + + Sends a validation token and verifies if the response contains the same token. + + Args: + url: URL to be verified + + Returns: + True if the URL is verified successfully, False otherwise + """ + async with httpx.AsyncClient(timeout=10) as client: + try: + validation_token = str(uuid.uuid4()) + response = await client.get( + url, params={"validationToken": validation_token} + ) + response.raise_for_status() + is_verified = response.text == validation_token + + logger.info(f"Push notification URL verified: {url} => {is_verified}") + return is_verified + except Exception as e: + logger.warning(f"Error verifying push notification URL {url}: {e}") + + return False + + def generate_jwk(self): + """ + Generates a new JWK pair for signing. + + The key pair is used to sign push notifications. + The public key is available via the JWKS endpoint. + """ + key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig") + self.public_keys.append(key.export_public(as_dict=True)) + self.private_key_jwk = PyJWK.from_json(key.export_private()) + + def handle_jwks_endpoint(self, _request: Request) -> JSONResponse: + """ + Handles the JWKS endpoint to allow clients to obtain the public keys. + + Args: + _request: HTTP request + + Returns: + JSON response with the public keys + """ + return JSONResponse({"keys": self.public_keys}) + + def _generate_jwt(self, data: Dict[str, Any]) -> str: + """ + Generates a JWT token by signing the SHA256 hash of the payload and the timestamp. + + The payload is signed with the private key to ensure integrity. + The timestamp (iat) prevents replay attacks. + + Args: + data: Payload data + + Returns: + Signed JWT token + """ + iat = int(time.time()) + + return jwt.encode( + { + "iat": iat, + "request_body_sha256": self._calculate_request_body_sha256(data), + }, + key=self.private_key_jwk.key, + headers={"kid": self.private_key_jwk.key_id}, + algorithm="RS256", + ) + + async def send_push_notification(self, url: str, data: Dict[str, Any]) -> bool: + """ + Sends an authenticated push notification to the specified URL. + + Args: + url: URL to send the notification + data: Notification data + + Returns: + True if the notification was sent successfully, False otherwise + """ + if not self.private_key_jwk: + logger.error( + "Attempt to send push notification without generating JWK keys" + ) + return False + + try: + jwt_token = self._generate_jwt(data) + headers = {"Authorization": f"Bearer {jwt_token}"} + + async with httpx.AsyncClient(timeout=10) as client: + response = await client.post(url, json=data, headers=headers) + response.raise_for_status() + logger.info(f"Push notification sent to URL: {url}") + return True + except Exception as e: + logger.warning(f"Error sending push notification to URL {url}: {e}") + return False + + +class PushNotificationReceiverAuth(PushNotificationAuth): + """ + Authentication for the push notification receiver. + + This class is used by clients to verify the authenticity + of notifications received from the A2A server. + """ + + def __init__(self): + """ + Initialize the push notification receiver authentication service. + """ + self.jwks_client = None + + async def load_jwks(self, jwks_url: str): + """ + Loads the public JWKS keys from a URL. + + Args: + jwks_url: URL of the JWKS endpoint + """ + self.jwks_client = PyJWKClient(jwks_url) + + async def verify_push_notification(self, request: Request) -> bool: + """ + Verifies the authenticity of a push notification. + + Args: + request: HTTP request containing the notification + + Returns: + True if the notification is authentic, False otherwise + + Raises: + ValueError: If the token is invalid or expired + """ + if not self.jwks_client: + logger.error("Attempt to verify notification without loading JWKS keys") + return False + + # Verify authentication header + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX): + logger.warning("Invalid authorization header") + return False + + try: + # Extract and verify token + token = auth_header[len(AUTH_HEADER_PREFIX) :] + signing_key = self.jwks_client.get_signing_key_from_jwt(token) + + # Decode token + decode_token = jwt.decode( + token, + signing_key.key, + options={"require": ["iat", "request_body_sha256"]}, + algorithms=["RS256"], + ) + + # Verify request body integrity + body_data = await request.json() + actual_body_sha256 = self._calculate_request_body_sha256(body_data) + if actual_body_sha256 != decode_token["request_body_sha256"]: + # The payload signature does not match the hash in the signed token + logger.warning("Request body hash does not match the token") + raise ValueError("Invalid request body") + + # Verify token age (maximum 5 minutes) + if time.time() - decode_token["iat"] > 60 * 5: + # Do not allow notifications older than 5 minutes + # This prevents replay attacks + logger.warning("Token expired") + raise ValueError("Token expired") + + return True + + except Exception as e: + logger.error(f"Error verifying push notification: {e}") + return False + + +# Global instance of the push notification sender authentication service +push_notification_auth = PushNotificationSenderAuth() + +# Generate JWK keys on initialization +push_notification_auth.generate_jwk() diff --git a/src/services/push_notification_service.py b/src/services/push_notification_service.py index 6473126f..999ba567 100644 --- a/src/services/push_notification_service.py +++ b/src/services/push_notification_service.py @@ -4,6 +4,8 @@ from datetime import datetime from typing import Dict, Any, Optional import asyncio +from src.services.push_notification_auth_service import push_notification_auth + logger = logging.getLogger(__name__) @@ -20,10 +22,24 @@ class PushNotificationService: headers: Optional[Dict[str, str]] = None, max_retries: int = 3, retry_delay: float = 1.0, + use_jwt_auth: bool = True, ) -> bool: """ - Envia uma notificação push para a URL especificada. - Implementa retry com backoff exponencial. + Send a push notification to the specified URL. + Implements exponential backoff retry. + + Args: + url: URL to send the notification + task_id: Task ID + state: Task state + message: Optional message + headers: Optional HTTP headers + max_retries: Maximum number of retries + retry_delay: Initial delay between retries + use_jwt_auth: Whether to use JWT authentication + + Returns: + True if the notification was sent successfully, False otherwise """ payload = { "taskId": task_id, @@ -32,18 +48,37 @@ class PushNotificationService: "message": message, } + # First URL verification + if use_jwt_auth: + is_url_valid = await push_notification_auth.verify_push_notification_url( + url + ) + if not is_url_valid: + logger.warning(f"Invalid push notification URL: {url}") + return False + for attempt in range(max_retries): try: - async with self.session.post( - url, json=payload, headers=headers or {}, timeout=10 - ) as response: - if response.status in (200, 201, 202, 204): + if use_jwt_auth: + # Use JWT authentication + success = await push_notification_auth.send_push_notification( + url, payload + ) + if success: return True - else: - logger.warning( - f"Push notification failed with status {response.status}. " - f"Attempt {attempt + 1}/{max_retries}" - ) + else: + # Legacy method without JWT authentication + async with self.session.post( + url, json=payload, headers=headers or {}, timeout=10 + ) as response: + if response.status in (200, 201, 202, 204): + logger.info(f"Push notification sent to URL: {url}") + return True + else: + logger.warning( + f"Failed to send push notification with status {response.status}. " + f"Attempt {attempt + 1}/{max_retries}" + ) except Exception as e: logger.error( f"Error sending push notification: {str(e)}. " @@ -56,9 +91,9 @@ class PushNotificationService: return False async def close(self): - """Fecha a sessão HTTP""" + """Close the HTTP session""" await self.session.close() -# Instância global do serviço +# Global instance of the push notification service push_notification_service = PushNotificationService() diff --git a/src/services/redis_cache_service.py b/src/services/redis_cache_service.py new file mode 100644 index 00000000..0c4bccda --- /dev/null +++ b/src/services/redis_cache_service.py @@ -0,0 +1,556 @@ +""" +Cache Redis service for the A2A protocol. + +This service provides an interface for storing and retrieving data related to tasks, +push notification configurations, and other A2A-related data. +""" + +import json +import logging +from typing import Any, Dict, List, Optional, Union +import asyncio +import redis.asyncio as aioredis +from redis.exceptions import RedisError +from src.config.redis import get_redis_config, get_a2a_config +import threading +import time + +logger = logging.getLogger(__name__) + + +class _InMemoryCacheFallback: + """ + Fallback in-memory cache implementation for when Redis is not available. + + This should only be used for development or testing environments. + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + """Singleton implementation.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + """Initialize cache storage.""" + if not getattr(self, "_initialized", False): + with self._lock: + if not getattr(self, "_initialized", False): + self._data = {} + self._ttls = {} + self._hash_data = {} + self._list_data = {} + self._data_lock = threading.Lock() + self._initialized = True + logger.warning( + "Initializing in-memory cache fallback (not for production)" + ) + + async def set(self, key, value, ex=None): + """Set a key with optional expiration.""" + with self._data_lock: + self._data[key] = value + if ex is not None: + self._ttls[key] = time.time() + ex + elif key in self._ttls: + del self._ttls[key] + return True + + async def setex(self, key, ex, value): + """Set a key with expiration.""" + return await self.set(key, value, ex) + + async def get(self, key): + """Get a key value.""" + with self._data_lock: + # Check if expired + if key in self._ttls and time.time() > self._ttls[key]: + del self._data[key] + del self._ttls[key] + return None + return self._data.get(key) + + async def delete(self, key): + """Delete a key.""" + with self._data_lock: + if key in self._data: + del self._data[key] + if key in self._ttls: + del self._ttls[key] + return 1 + return 0 + + async def exists(self, key): + """Check if key exists.""" + with self._data_lock: + if key in self._ttls and time.time() > self._ttls[key]: + del self._data[key] + del self._ttls[key] + return 0 + return 1 if key in self._data else 0 + + async def hset(self, key, field, value): + """Set a hash field.""" + with self._data_lock: + if key not in self._hash_data: + self._hash_data[key] = {} + self._hash_data[key][field] = value + return 1 + + async def hget(self, key, field): + """Get a hash field.""" + with self._data_lock: + if key not in self._hash_data: + return None + return self._hash_data[key].get(field) + + async def hdel(self, key, field): + """Delete a hash field.""" + with self._data_lock: + if key in self._hash_data and field in self._hash_data[key]: + del self._hash_data[key][field] + return 1 + return 0 + + async def hgetall(self, key): + """Get all hash fields.""" + with self._data_lock: + if key not in self._hash_data: + return {} + return dict(self._hash_data[key]) + + async def rpush(self, key, value): + """Add to a list.""" + with self._data_lock: + if key not in self._list_data: + self._list_data[key] = [] + self._list_data[key].append(value) + return len(self._list_data[key]) + + async def lrange(self, key, start, end): + """Get range from list.""" + with self._data_lock: + if key not in self._list_data: + return [] + + # Handle negative indices + if end < 0: + end = len(self._list_data[key]) + end + 1 + + return self._list_data[key][start:end] + + async def expire(self, key, seconds): + """Set expiration on key.""" + with self._data_lock: + if key in self._data: + self._ttls[key] = time.time() + seconds + return 1 + return 0 + + async def flushdb(self): + """Clear all data.""" + with self._data_lock: + self._data.clear() + self._ttls.clear() + self._hash_data.clear() + self._list_data.clear() + return True + + async def keys(self, pattern="*"): + """Get keys matching pattern.""" + with self._data_lock: + # Clean expired keys + now = time.time() + expired_keys = [k for k, exp in self._ttls.items() if now > exp] + for k in expired_keys: + if k in self._data: + del self._data[k] + del self._ttls[k] + + # Simple pattern matching + result = [] + if pattern == "*": + result = list(self._data.keys()) + elif pattern.endswith("*"): + prefix = pattern[:-1] + result = [k for k in self._data.keys() if k.startswith(prefix)] + elif pattern.startswith("*"): + suffix = pattern[1:] + result = [k for k in self._data.keys() if k.endswith(suffix)] + else: + if pattern in self._data: + result = [pattern] + + return result + + async def ping(self): + """Test connection.""" + return True + + +class RedisCacheService: + """ + Cache service using Redis for storing A2A-related data. + + This implementation uses a real Redis connection for distributed caching + and data persistence across multiple instances. + + If Redis is not available, falls back to an in-memory implementation. + """ + + def __init__(self, redis_url: Optional[str] = None): + """ + Initialize the Redis cache service. + + Args: + redis_url: Redis server URL (optional, defaults to config value) + """ + if redis_url: + self._redis_url = redis_url + else: + # Construir URL a partir dos componentes de configuração + config = get_redis_config() + protocol = "rediss" if config.get("ssl", False) else "redis" + auth = f":{config['password']}@" if config.get("password") else "" + self._redis_url = ( + f"{protocol}://{auth}{config['host']}:{config['port']}/{config['db']}" + ) + + self._redis = None + self._in_memory_mode = False + self._connecting = False + self._connection_lock = asyncio.Lock() + logger.info(f"Initializing RedisCacheService with URL: {self._redis_url}") + + async def _get_redis(self): + """ + Get a Redis connection, creating it if necessary. + Falls back to in-memory implementation if Redis is not available. + + Returns: + Redis connection or in-memory fallback + """ + if self._redis is not None: + return self._redis + + async with self._connection_lock: + if self._redis is None and not self._connecting: + try: + self._connecting = True + logger.info(f"Connecting to Redis at {self._redis_url}") + self._redis = aioredis.from_url( + self._redis_url, encoding="utf-8", decode_responses=True + ) + # Teste de conexão + await self._redis.ping() + logger.info("Redis connection successful") + except Exception as e: + logger.error(f"Error connecting to Redis: {str(e)}") + logger.warning( + "Falling back to in-memory cache (not suitable for production)" + ) + self._redis = _InMemoryCacheFallback() + self._in_memory_mode = True + finally: + self._connecting = False + + return self._redis + + async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + """ + Store a value in the cache. + + Args: + key: Key for the value + value: Value to store + ttl: Time to live in seconds (optional) + """ + try: + redis = await self._get_redis() + + # Convert dict/list to JSON string + if isinstance(value, (dict, list)): + value = json.dumps(value) + + if ttl: + await redis.setex(key, ttl, value) + else: + await redis.set(key, value) + + logger.debug(f"Set cache key: {key}") + except Exception as e: + logger.error(f"Error setting Redis key {key}: {str(e)}") + + async def get(self, key: str) -> Optional[Any]: + """ + Retrieve a value from the cache. + + Args: + key: Key for the value to retrieve + + Returns: + The stored value or None if not found + """ + try: + redis = await self._get_redis() + value = await redis.get(key) + + if value is None: + return None + + try: + # Try to parse as JSON + return json.loads(value) + except json.JSONDecodeError: + # Return as-is if not JSON + return value + + except Exception as e: + logger.error(f"Error getting Redis key {key}: {str(e)}") + return None + + async def delete(self, key: str) -> bool: + """ + Remove a value from the cache. + + Args: + key: Key for the value to remove + + Returns: + True if the value was removed, False if it didn't exist + """ + try: + redis = await self._get_redis() + result = await redis.delete(key) + return result > 0 + except Exception as e: + logger.error(f"Error deleting Redis key {key}: {str(e)}") + return False + + async def exists(self, key: str) -> bool: + """ + Check if a key exists in the cache. + + Args: + key: Key to check + + Returns: + True if the key exists, False otherwise + """ + try: + redis = await self._get_redis() + return await redis.exists(key) > 0 + except Exception as e: + logger.error(f"Error checking Redis key {key}: {str(e)}") + return False + + async def set_hash(self, key: str, field: str, value: Any) -> None: + """ + Store a value in a hash. + + Args: + key: Hash key + field: Hash field + value: Value to store + """ + try: + redis = await self._get_redis() + + # Convert dict/list to JSON string + if isinstance(value, (dict, list)): + value = json.dumps(value) + + await redis.hset(key, field, value) + logger.debug(f"Set hash field: {key}:{field}") + except Exception as e: + logger.error(f"Error setting Redis hash {key}:{field}: {str(e)}") + + async def get_hash(self, key: str, field: str) -> Optional[Any]: + """ + Retrieve a value from a hash. + + Args: + key: Hash key + field: Hash field + + Returns: + The stored value or None if not found + """ + try: + redis = await self._get_redis() + value = await redis.hget(key, field) + + if value is None: + return None + + try: + # Try to parse as JSON + return json.loads(value) + except json.JSONDecodeError: + # Return as-is if not JSON + return value + + except Exception as e: + logger.error(f"Error getting Redis hash {key}:{field}: {str(e)}") + return None + + async def delete_hash(self, key: str, field: str) -> bool: + """ + Remove a value from a hash. + + Args: + key: Hash key + field: Hash field + + Returns: + True if the value was removed, False if it didn't exist + """ + try: + redis = await self._get_redis() + result = await redis.hdel(key, field) + return result > 0 + except Exception as e: + logger.error(f"Error deleting Redis hash {key}:{field}: {str(e)}") + return False + + async def get_all_hash(self, key: str) -> Dict[str, Any]: + """ + Retrieve all values from a hash. + + Args: + key: Hash key + + Returns: + Dictionary with all hash values + """ + try: + redis = await self._get_redis() + result_dict = await redis.hgetall(key) + + if not result_dict: + return {} + + # Try to parse each value as JSON + parsed_dict = {} + for field, value in result_dict.items(): + try: + parsed_dict[field] = json.loads(value) + except json.JSONDecodeError: + parsed_dict[field] = value + + return parsed_dict + + except Exception as e: + logger.error(f"Error getting all Redis hash fields for {key}: {str(e)}") + return {} + + async def push_list(self, key: str, value: Any) -> int: + """ + Add a value to the end of a list. + + Args: + key: List key + value: Value to add + + Returns: + Size of the list after the addition + """ + try: + redis = await self._get_redis() + + # Convert dict/list to JSON string + if isinstance(value, (dict, list)): + value = json.dumps(value) + + return await redis.rpush(key, value) + except Exception as e: + logger.error(f"Error pushing to Redis list {key}: {str(e)}") + return 0 + + async def get_list(self, key: str, start: int = 0, end: int = -1) -> List[Any]: + """ + Retrieve values from a list. + + Args: + key: List key + start: Initial index (inclusive) + end: Final index (inclusive), -1 for all + + Returns: + List with the retrieved values + """ + try: + redis = await self._get_redis() + values = await redis.lrange(key, start, end) + + if not values: + return [] + + # Try to parse each value as JSON + result = [] + for value in values: + try: + result.append(json.loads(value)) + except json.JSONDecodeError: + result.append(value) + + return result + + except Exception as e: + logger.error(f"Error getting Redis list {key}: {str(e)}") + return [] + + async def expire(self, key: str, ttl: int) -> bool: + """ + Set a time-to-live for a key. + + Args: + key: Key + ttl: Time-to-live in seconds + + Returns: + True if the key exists and the TTL was set, False otherwise + """ + try: + redis = await self._get_redis() + return await redis.expire(key, ttl) + except Exception as e: + logger.error(f"Error setting expire for Redis key {key}: {str(e)}") + return False + + async def clear(self) -> None: + """ + Clear the entire cache. + + Warning: This is a destructive operation and will remove all data. + Only use in development/test environments. + """ + try: + redis = await self._get_redis() + await redis.flushdb() + logger.warning("Redis database flushed - all data cleared") + except Exception as e: + logger.error(f"Error clearing Redis database: {str(e)}") + + async def keys(self, pattern: str = "*") -> List[str]: + """ + Retrieve keys that match a pattern. + + Args: + pattern: Glob pattern to filter keys + + Returns: + List of keys that match the pattern + """ + try: + redis = await self._get_redis() + return await redis.keys(pattern) + except Exception as e: + logger.error(f"Error getting Redis keys with pattern {pattern}: {str(e)}") + return [] diff --git a/src/services/streaming_service.py b/src/services/streaming_service.py index ff385b1f..4cee4455 100644 --- a/src/services/streaming_service.py +++ b/src/services/streaming_service.py @@ -3,12 +3,12 @@ import json from typing import AsyncGenerator, Dict, Any from fastapi import HTTPException from datetime import datetime -from ..schemas.streaming import ( +from src.schemas.streaming import ( JSONRPCRequest, TaskStatusUpdateEvent, ) -from ..services.agent_runner import run_agent -from ..services.service_providers import ( +from src.services.agent_runner import run_agent +from src.services.service_providers import ( session_service, artifacts_service, memory_service, @@ -25,31 +25,33 @@ class StreamingService: agent_id: str, api_key: str, message: str, + contact_id: str = None, session_id: str = None, db: Session = None, ) -> AsyncGenerator[str, None]: """ - Inicia o streaming de eventos SSE para uma tarefa. + Starts the SSE event streaming for a task. Args: - agent_id: ID do agente - api_key: Chave de API para autenticação - message: Mensagem inicial - session_id: ID da sessão (opcional) - db: Sessão do banco de dados + agent_id: Agent ID + api_key: API key for authentication + message: Initial message + contact_id: Contact ID (optional) + session_id: Session ID (optional) + db: Database session Yields: - Eventos SSE formatados + Formatted SSE events """ - # Validação básica da API key + # Basic API key validation if not api_key: - raise HTTPException(status_code=401, detail="API key é obrigatória") + raise HTTPException(status_code=401, detail="API key is required") - # Gera IDs únicos - task_id = str(uuid.uuid4()) + # Generate unique IDs + task_id = contact_id or str(uuid.uuid4()) request_id = str(uuid.uuid4()) - # Monta payload JSON-RPC + # Build JSON-RPC payload payload = JSONRPCRequest( id=request_id, params={ @@ -62,7 +64,7 @@ class StreamingService: }, ) - # Registra conexão + # Register connection self.active_connections[task_id] = { "agent_id": agent_id, "api_key": api_key, @@ -70,7 +72,7 @@ class StreamingService: } try: - # Envia evento de início + # Send start event yield self._format_sse_event( "status", TaskStatusUpdateEvent( @@ -80,10 +82,10 @@ class StreamingService: ).model_dump_json(), ) - # Executa o agente + # Execute the agent result = await run_agent( str(agent_id), - task_id, + contact_id or task_id, message, session_service, artifacts_service, @@ -92,7 +94,7 @@ class StreamingService: session_id, ) - # Envia a resposta do agente como um evento separado + # Send the agent's response as a separate event yield self._format_sse_event( "message", json.dumps( @@ -104,7 +106,7 @@ class StreamingService: ), ) - # Evento de conclusão + # Completion event yield self._format_sse_event( "status", TaskStatusUpdateEvent( @@ -114,7 +116,7 @@ class StreamingService: ) except Exception as e: - # Evento de erro + # Error event yield self._format_sse_event( "status", TaskStatusUpdateEvent( @@ -126,14 +128,14 @@ class StreamingService: raise finally: - # Limpa conexão + # Clean connection self.active_connections.pop(task_id, None) def _format_sse_event(self, event_type: str, data: str) -> str: - """Formata um evento SSE.""" + """Format an SSE event.""" return f"event: {event_type}\ndata: {data}\n\n" async def close_connection(self, task_id: str): - """Fecha uma conexão de streaming.""" + """Close a streaming connection.""" if task_id in self.active_connections: self.active_connections.pop(task_id) diff --git a/src/utils/a2a_utils.py b/src/utils/a2a_utils.py new file mode 100644 index 00000000..e2a93e4b --- /dev/null +++ b/src/utils/a2a_utils.py @@ -0,0 +1,110 @@ +""" +A2A protocol utilities. + +This module contains utility functions for the A2A protocol implementation. +""" + +import logging +from typing import List, Optional, Any, Dict +from src.schemas.a2a import ( + ContentTypeNotSupportedError, + UnsupportedOperationError, + JSONRPCResponse, +) + +logger = logging.getLogger(__name__) + + +def are_modalities_compatible( + server_output_modes: Optional[List[str]], client_output_modes: Optional[List[str]] +) -> bool: + """ + Check if server and client modalities are compatible. + + Modalities are compatible if they are both non-empty + and there is at least one common element. + + Args: + server_output_modes: List of output modes supported by the server + client_output_modes: List of output modes requested by the client + + Returns: + True if compatible, False otherwise + """ + # If client doesn't specify modes, assume all are accepted + if client_output_modes is None or len(client_output_modes) == 0: + return True + + # If server doesn't specify modes, assume all are supported + if server_output_modes is None or len(server_output_modes) == 0: + return True + + # Check if there's at least one common mode + return any(mode in server_output_modes for mode in client_output_modes) + + +def new_incompatible_types_error(request_id: str) -> JSONRPCResponse: + """ + Create a JSON-RPC response for incompatible content types error. + + Args: + request_id: The ID of the request that caused the error + + Returns: + JSON-RPC response with ContentTypeNotSupportedError + """ + return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError()) + + +def new_not_implemented_error(request_id: str) -> JSONRPCResponse: + """ + Create a JSON-RPC response for unimplemented operation error. + + Args: + request_id: The ID of the request that caused the error + + Returns: + JSON-RPC response with UnsupportedOperationError + """ + return JSONRPCResponse(id=request_id, error=UnsupportedOperationError()) + + +def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str: + """ + Create a unique task ID for an agent and session. + + Args: + agent_id: The ID of the agent + session_id: The ID of the session + timestamp: Optional timestamp to include in the ID + + Returns: + A unique task ID + """ + import uuid + import time + + timestamp = timestamp or str(int(time.time())) + unique = uuid.uuid4().hex[:8] + + return f"{agent_id}_{session_id}_{timestamp}_{unique}" + + +def format_error_response(error: Exception, request_id: str = None) -> Dict[str, Any]: + """ + Format an exception as a JSON-RPC error response. + + Args: + error: The exception to format + request_id: The ID of the request that caused the error + + Returns: + JSON-RPC error response as dictionary + """ + from src.schemas.a2a import InternalError, JSONRPCResponse + + error_response = JSONRPCResponse( + id=request_id, error=InternalError(message=str(error)) + ) + + return error_response.model_dump(exclude_none=True)