feat(a2a): add Redis configuration and A2A service implementation

This commit is contained in:
Davidson Gomes 2025-04-30 16:39:48 -03:00
parent 910979fb83
commit 4901be8e4c
26 changed files with 4630 additions and 1158 deletions

.env

@@ -18,6 +18,9 @@ REDIS_HOST="localhost"
REDIS_PORT=6379
REDIS_DB=8
REDIS_PASSWORD=""
REDIS_SSL=false
REDIS_KEY_PREFIX="a2a:"
REDIS_TTL=3600
# Tools cache TTL in seconds (1 hour)
TOOLS_CACHE_TTL=3600


@@ -18,6 +18,9 @@ REDIS_HOST="localhost"
REDIS_PORT=6379
REDIS_DB=0
REDIS_PASSWORD="your-redis-password"
REDIS_SSL=false
REDIS_KEY_PREFIX="a2a:"
REDIS_TTL=3600
# Tools cache TTL in seconds (1 hour)
TOOLS_CACHE_TTL=3600
@@ -44,3 +47,9 @@ ADMIN_INITIAL_PASSWORD="strongpassword123"
DEMO_EMAIL="demo@example.com"
DEMO_PASSWORD="demo123"
DEMO_CLIENT_NAME="Demo Client"
# A2A settings
A2A_TASK_TTL=3600
A2A_HISTORY_TTL=86400
A2A_PUSH_NOTIFICATION_TTL=3600
A2A_SSE_CLIENT_TTL=300


@@ -1,252 +0,0 @@
# A2A Protocol Implementation Checklist with Redis

## 1. Initial Setup

- [ ] **Configure dependencies in pyproject.toml**
  - Add Redis and related dependencies:
    ```
    redis = "^5.3.0"
    sse-starlette = "^2.3.3"
    jwcrypto = "^1.5.6"
    pyjwt = {extras = ["crypto"], version = "^2.10.1"}
    ```
- [ ] **Configure the Redis environment variables**
  - Add to `.env.example` and `.env`:
    ```
    REDIS_HOST=localhost
    REDIS_PORT=6379
    REDIS_PASSWORD=
    REDIS_DB=0
    REDIS_SSL=false
    REDIS_KEY_PREFIX=a2a:
    REDIS_TTL=3600
    ```
- [ ] **Configure Redis in docker-compose.yml**
  - Add a Redis service with appropriate ports and volumes
  - Configure basic security (password, if needed)
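As a sketch of how these variables might be loaded at startup, here is a stdlib-only helper. The variable names match the `.env` entries above, but the `RedisSettings` class itself is an illustrative assumption, not existing project code:

```python
import os
from dataclasses import dataclass


@dataclass
class RedisSettings:
    """Illustrative container for the Redis variables listed above."""

    host: str
    port: int
    db: int
    password: str
    ssl: bool
    key_prefix: str
    ttl: int

    @classmethod
    def from_env(cls) -> "RedisSettings":
        # Defaults mirror the values shown in the checklist
        return cls(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            db=int(os.getenv("REDIS_DB", "0")),
            password=os.getenv("REDIS_PASSWORD", ""),
            ssl=os.getenv("REDIS_SSL", "false").lower() == "true",
            key_prefix=os.getenv("REDIS_KEY_PREFIX", "a2a:"),
            ttl=int(os.getenv("REDIS_TTL", "3600")),
        )
```

Parsing booleans and integers explicitly at the edge keeps the rest of the code free of string/int coercion bugs.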
## 2. Models and Schemas

- [ ] **Create the A2A schemas in `src/schemas/a2a.py`**
  - Implement the types defined in `docs/A2A/samples/python/common/types.py`:
    - Enums (TaskState, etc.)
    - Message classes (TextPart, FilePart, etc.)
    - Task classes (Task, TaskStatus, etc.)
    - JSON-RPC structures
    - Error types
- [ ] **Implement model validators**
  - Validators for file content
  - Validators for message formats
  - Format converters for protocol compatibility
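To illustrate the kind of validation involved, here is a minimal framework-free sketch. The part shapes follow the A2A sample types, but the `validate_part` helper is an illustrative assumption rather than the project's real validator:

```python
def validate_part(part: dict) -> dict:
    """Validate a message part dict: text parts need a 'text' string,
    file parts need exactly one of 'bytes' or 'uri'."""
    kind = part.get("type")
    if kind == "text":
        if not isinstance(part.get("text"), str):
            raise ValueError("text part requires a 'text' string")
    elif kind == "file":
        file = part.get("file", {})
        has_bytes = file.get("bytes") is not None
        has_uri = file.get("uri") is not None
        if has_bytes == has_uri:  # both set, or neither
            raise ValueError("file part requires exactly one of 'bytes' or 'uri'")
    else:
        raise ValueError(f"unsupported part type: {kind}")
    return part
```

In the real implementation this logic would live in Pydantic validators on the schema classes.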
## 3. Redis Cache Implementation

- [ ] **Create the Redis configuration in `src/config/redis.py`**
  - Implement a pooled connection function
  - Configure security options (SSL, authentication)
  - Configure default TTLs for the different data types
- [ ] **Create the Redis cache service in `src/services/redis_cache_service.py`**
  - Implement the reference methods on top of Redis:
    ```python
    class RedisCache:
        async def get_task(self, task_id: str) -> dict
        async def save_task(self, task_id: str, task_data: dict, ttl: int = 3600) -> None
        async def update_task_status(self, task_id: str, status: dict) -> bool
        async def append_to_history(self, task_id: str, message: dict) -> bool
        async def save_push_notification_config(self, task_id: str, config: dict) -> None
        async def get_push_notification_config(self, task_id: str) -> dict
        async def save_sse_client(self, task_id: str, client_id: str) -> None
        async def get_sse_clients(self, task_id: str) -> list
        async def remove_sse_client(self, task_id: str, client_id: str) -> None
    ```
- [ ] **Implement connection management features**
  - Automatic reconnection
  - Fallback to an in-memory cache on failure
  - Performance metrics
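The in-memory fallback mentioned above could look like this minimal sketch, with lazy TTL expiry via `time.monotonic`. The class name and interface are assumptions for illustration only:

```python
import time
from typing import Any, Optional


class InMemoryFallbackCache:
    """Dict-backed stand-in used when the Redis connection is down."""

    def __init__(self) -> None:
        self._store = {}  # key -> (value, expires_at or None)

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        expires_at = time.monotonic() + ttl if ttl else None
        self._store[key] = (value, expires_at)

    def get(self, key: str) -> Any:
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if expires_at is not None and time.monotonic() >= expires_at:
            del self._store[key]  # expire lazily on read
            return None
        return value
```

A production fallback would also need bounded memory (e.g. an LRU cap), since unlike Redis it has no eviction policy.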
## 4. A2A Services

- [ ] **Implement A2A utilities in `src/utils/a2a_utils.py`**
  - Implement the functions from `docs/A2A/samples/python/common/server/utils.py`:
    ```python
    def are_modalities_compatible(server_output_modes, client_output_modes)
    def new_incompatible_types_error(request_id)
    def new_not_implemented_error(request_id)
    ```
- [ ] **Implement `A2ATaskManager` in `src/services/a2a_task_manager_service.py`**
  - Follow the `TaskManager` interface from the reference sample
  - Implement all abstract methods:
    ```python
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse
    async def on_send_task_subscribe(self, request) -> Union[AsyncIterable, JSONRPCResponse]
    async def on_set_task_push_notification(self, request) -> SetTaskPushNotificationResponse
    async def on_get_task_push_notification(self, request) -> GetTaskPushNotificationResponse
    async def on_resubscribe_to_task(self, request) -> Union[AsyncIterable, JSONRPCResponse]
    ```
  - Use Redis for task data persistence
- [ ] **Implement `A2AServer` in `src/services/a2a_server_service.py`**
  - Process JSON-RPC requests as in `docs/A2A/samples/python/common/server/server.py`:
    ```python
    async def _process_request(self, request: Request)
    def _handle_exception(self, e: Exception) -> JSONResponse
    def _create_response(self, result: Any) -> Union[JSONResponse, EventSourceResponse]
    ```
- [ ] **Integrate with the existing agent_runner.py**
  - Adapt `run_agent` for use in the A2A task context
  - Implement mapping between message formats
- [ ] **Integrate with the existing streaming_service.py**
  - Adapt to an A2A-compatible event format
  - Support streaming of multiple event types
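The modality check listed above follows a simple rule: a missing or empty mode list on either side means "no constraint", otherwise the two lists must overlap. A sketch of that logic (behavior inferred from the sample's usage, so treat the details as an assumption):

```python
def are_modalities_compatible(server_output_modes, client_output_modes):
    """Empty/None on either side means 'no constraint'; otherwise
    at least one output mode must be supported by both sides."""
    if not client_output_modes or not server_output_modes:
        return True
    return any(mode in server_output_modes for mode in client_output_modes)
```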
## 5. Authentication and Push Notifications

- [ ] **Implement `PushNotificationAuth` in `src/services/push_notification_auth_service.py`**
  - Follow the example in `docs/A2A/samples/python/common/utils/push_notification_auth.py`
  - Implement:
    ```python
    def generate_jwk(self)
    def handle_jwks_endpoint(self, request: Request)
    async def send_authenticated_push_notification(self, url: str, data: dict)
    ```
- [ ] **Implement notification URL verification**
  - Follow the sample's `verify_push_notification_url` method
  - Implement token validation for the verification step
- [ ] **Implement secure key storage**
  - Store private keys securely
  - Rotate keys periodically
  - Manage the key lifecycle
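To show the shape of signed push delivery without pulling in `pyjwt`, here is a stdlib-only sketch that signs the payload with HMAC-SHA256. The real implementation uses RS256 JWTs with the JWKS endpoint above; this is only an illustration of the sign-then-verify flow, and both helper names are assumptions:

```python
import base64
import hashlib
import hmac
import json


def sign_payload(secret: bytes, payload: dict) -> str:
    """Return a base64url signature over the canonical JSON payload."""
    body = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
    digest = hmac.new(secret, body, hashlib.sha256).digest()
    return base64.urlsafe_b64encode(digest).decode().rstrip("=")


def verify_payload(secret: bytes, payload: dict, signature: str) -> bool:
    """Constant-time check of the signature the receiver got."""
    return hmac.compare_digest(sign_payload(secret, payload), signature)
```

Canonical JSON (sorted keys, fixed separators) matters here: without it, the sender and receiver can serialize the same payload differently and verification fails spuriously.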
## 6. A2A Routes

- [ ] **Implement routes in `src/api/a2a_routes.py`**
  - Create the main endpoint for JSON-RPC request processing:
    ```python
    @router.post("/{agent_id}")
    async def process_a2a_request(agent_id: str, request: Request, x_api_key: str = Header(None))
    ```
  - Implement the Agent Card endpoint reusing existing logic:
    ```python
    @router.get("/{agent_id}/.well-known/agent.json")
    async def get_agent_card(agent_id: str, db: Session = Depends(get_db))
    ```
  - Implement the JWKS endpoint for push notification authentication:
    ```python
    @router.get("/{agent_id}/.well-known/jwks.json")
    async def get_jwks(agent_id: str, db: Session = Depends(get_db))
    ```
- [ ] **Register the A2A routes in the main application**
  - Import and include them in `src/main.py`:
    ```python
    app.include_router(a2a_routes.router, prefix="/api/v1")
    ```
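Behind the main endpoint, request handling reduces to dispatching on the JSON-RPC `method` field. A framework-free sketch of that dispatch (the registry shape is an assumption; the real server dispatches on typed request classes instead):

```python
import json


def make_dispatcher(handlers: dict):
    """Return a callable that routes a JSON-RPC 2.0 request body
    to the handler registered for its 'method'."""

    def dispatch(raw_body: str) -> dict:
        req = json.loads(raw_body)
        handler = handlers.get(req.get("method"))
        if handler is None:
            # Standard JSON-RPC "method not found" error code
            return {
                "jsonrpc": "2.0",
                "id": req.get("id"),
                "error": {"code": -32601, "message": "Method not found"},
            }
        return {"jsonrpc": "2.0", "id": req.get("id"), "result": handler(req.get("params"))}

    return dispatch


dispatch = make_dispatcher(
    {"tasks/get": lambda params: {"id": params["id"], "status": "working"}}
)
```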
## 7. Tests

- [ ] **Create unit tests for the A2A schemas**
  - Test validators
  - Test format conversions
  - Test modality compatibility
- [ ] **Create unit tests for the Redis cache**
  - Test all CRUD operations
  - Test data expiration
  - Test behavior under connection failures
- [ ] **Create unit tests for the task manager**
  - Test the task lifecycle
  - Test task cancellation
  - Test push notifications
- [ ] **Create integration tests for the A2A endpoints**
  - Test complete requests
  - Test streaming
  - Test error scenarios
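To give a flavor of the task-lifecycle tests, here is a self-contained sketch. The simplified `Task` state machine is illustrative only; the state names follow the A2A TaskState enum, but the transition table is an assumption:

```python
VALID_TRANSITIONS = {
    "submitted": {"working", "canceled"},
    "working": {"completed", "failed", "canceled", "input-required"},
    "input-required": {"working", "canceled"},
    # terminal states (completed/failed/canceled) allow no transitions
}


class Task:
    def __init__(self, task_id: str) -> None:
        self.id = task_id
        self.state = "submitted"

    def transition(self, new_state: str) -> None:
        allowed = VALID_TRANSITIONS.get(self.state, set())
        if new_state not in allowed:
            raise ValueError(f"illegal transition {self.state} -> {new_state}")
        self.state = new_state


def test_lifecycle():
    task = Task("t1")
    task.transition("working")
    task.transition("completed")
    assert task.state == "completed"


def test_cancel_completed_rejected():
    task = Task("t2")
    task.transition("working")
    task.transition("completed")
    try:
        task.transition("canceled")
        raise AssertionError("expected ValueError")
    except ValueError:
        pass
```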
## 8. Security

- [ ] **Implement API key validation**
  - Verify the API key on every request
  - Implement per-agent/per-client rate limiting
- [ ] **Harden Redis**
  - Enable authentication and SSL in production
  - Define data retention policies
  - Implement backup and recovery
- [ ] **Secure push notifications**
  - Implement JWT signing
  - Validate callback URLs
  - Implement retry with backoff on failures

## 9. Monitoring and Metrics

- [ ] **Implement Redis metrics**
  - Cache hit/miss rate
  - Response time
  - Memory usage
- [ ] **Implement A2A task metrics**
  - Number of tasks per state
  - Average processing time
  - Error rate
- [ ] **Configure appropriate logging**
  - Log important events
  - Mask sensitive data
  - Implement configurable log levels

## 10. Documentation

- [ ] **Document the A2A API**
  - Describe endpoints and formats
  - Provide usage examples
  - Document errors and resolutions
- [ ] **Document the Redis integration**
  - Describe the configuration
  - Explain the caching strategy
  - Document TTLs and expiration policies
- [ ] **Create client examples**
  - Implement usage examples in Python
  - Document common flows
  - Provide snippets for popular languages

a2a_client_test.py (new file)

@@ -0,0 +1,187 @@
import logging
import httpx
from httpx_sse import connect_sse
from typing import Any, AsyncIterable, Optional
from docs.A2A.samples.python.common.types import (
    AgentCard,
    GetTaskRequest,
    SendTaskRequest,
    SendTaskResponse,
    JSONRPCRequest,
    GetTaskResponse,
    CancelTaskResponse,
    CancelTaskRequest,
    SetTaskPushNotificationRequest,
    SetTaskPushNotificationResponse,
    GetTaskPushNotificationRequest,
    GetTaskPushNotificationResponse,
    A2AClientHTTPError,
    A2AClientJSONError,
    SendTaskStreamingRequest,
    SendTaskStreamingResponse,
)
import json
import asyncio
import uuid

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("a2a_client_runner")


class A2ACardResolver:
    def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
        self.base_url = base_url.rstrip("/")
        self.agent_card_path = agent_card_path.lstrip("/")

    def get_agent_card(self) -> AgentCard:
        with httpx.Client() as client:
            response = client.get(self.base_url + "/" + self.agent_card_path)
            response.raise_for_status()
            try:
                return AgentCard(**response.json())
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e


class A2AClient:
    def __init__(
        self,
        agent_card: AgentCard = None,
        url: str = None,
        api_key: Optional[str] = None,
    ):
        if agent_card:
            self.url = agent_card.url
        elif url:
            self.url = url
        else:
            raise ValueError("Must provide either agent_card or url")
        self.api_key = api_key
        self.headers = {"x-api-key": api_key} if api_key else {}

    async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
        request = SendTaskRequest(params=payload)
        return SendTaskResponse(**await self._send_request(request))

    async def send_task_streaming(
        self, payload: dict[str, Any]
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        request = SendTaskStreamingRequest(params=payload)
        with httpx.Client(timeout=None) as client:
            with connect_sse(
                client,
                "POST",
                self.url,
                json=request.model_dump(),
                headers=self.headers,
            ) as event_source:
                try:
                    for sse in event_source.iter_sse():
                        yield SendTaskStreamingResponse(**json.loads(sse.data))
                except json.JSONDecodeError as e:
                    raise A2AClientJSONError(str(e)) from e
                except httpx.RequestError as e:
                    raise A2AClientHTTPError(400, str(e)) from e

    async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
        async with httpx.AsyncClient() as client:
            try:
                # Responses can take a while (e.g. image generation), so
                # set an explicit timeout
                response = await client.post(
                    self.url,
                    json=request.model_dump(),
                    headers=self.headers,
                    timeout=30,
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                raise A2AClientHTTPError(e.response.status_code, str(e)) from e
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e

    async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
        request = GetTaskRequest(params=payload)
        return GetTaskResponse(**await self._send_request(request))

    async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
        request = CancelTaskRequest(params=payload)
        return CancelTaskResponse(**await self._send_request(request))

    async def set_task_callback(
        self, payload: dict[str, Any]
    ) -> SetTaskPushNotificationResponse:
        request = SetTaskPushNotificationRequest(params=payload)
        return SetTaskPushNotificationResponse(**await self._send_request(request))

    async def get_task_callback(
        self, payload: dict[str, Any]
    ) -> GetTaskPushNotificationResponse:
        request = GetTaskPushNotificationRequest(params=payload)
        return GetTaskPushNotificationResponse(**await self._send_request(request))


async def main():
    # Settings
    BASE_URL = "http://localhost:8000/api/v1/a2a/18a2889e-8573-4e70-833c-7d9e00a8fd80"
    API_KEY = "83c2c19f-dc2e-4abe-9a41-ef7d2eb079d6"
    try:
        # Fetch the agent card
        logger.info("Fetching agent card...")
        card_resolver = A2ACardResolver(BASE_URL)
        try:
            card = card_resolver.get_agent_card()
            logger.info(f"Agent card: {card}")
        except Exception as e:
            logger.error(f"Error fetching agent card: {e}")
            return

        # Create the A2A client with the API key
        client = A2AClient(card, api_key=API_KEY)

        # Example 1: send a synchronous task
        logger.info("\n=== SYNCHRONOUS TASK TEST ===")
        task_id = str(uuid.uuid4())
        session_id = "test-session-1"

        # Build the task payload
        payload = {
            "id": task_id,
            "sessionId": session_id,
            "message": {
                "role": "user",
                "parts": [
                    {
                        "type": "text",
                        "text": "What are the three largest countries in the world by land area?",
                    }
                ],
            },
        }

        logger.info(f"Sending task with ID: {task_id}")
        async for streaming_response in client.send_task_streaming(payload):
            if hasattr(streaming_response.result, "artifact"):
                # Handle partial content
                print(streaming_response.result.artifact.parts[0].text)
            elif (
                hasattr(streaming_response.result, "status")
                and streaming_response.result.status.state == "completed"
            ):
                # Task finished
                print(
                    "Final response:",
                    streaming_response.result.status.message.parts[0].text,
                )
    except Exception as e:
        logger.error(f"Error while running the tests: {e}")


if __name__ == "__main__":
    asyncio.run(main())


@@ -1,491 +0,0 @@
# A2A (Agent-to-Agent) Server Implementation

## Overview

This document describes the implementation plan for integrating the A2A (Agent-to-Agent) server into the existing system. The implementation follows the reference sample files, adapting them to the current project structure.

## Components to Implement

### 1. A2A Server

Implement the `A2AServer` class as a service that handles JSON-RPC requests compatible with the A2A protocol.

### 2. Task Manager

Implement the `TaskManager` as a service that manages the lifecycle of agent tasks.

### 3. Integration Adapters

Create adapters to integrate the A2A server with existing services such as streaming_service and push_notification_service.

## A2A Routes and Endpoints

The A2A protocol requires the following JSON-RPC routes:

### 1. `POST /a2a/{agent_id}`

Main endpoint that processes all JSON-RPC requests for a specific agent.

### 2. `GET /a2a/{agent_id}/.well-known/agent.json`

Returns the Agent Card with the agent's information.

### Implemented JSON-RPC methods:

1. **tasks/send** - Sends a new task to the agent.
   - Parameters: `id`, `sessionId`, `message`, `acceptedOutputModes` (optional), `pushNotification` (optional)
   - Returns: task status, generated artifacts, and history
2. **tasks/sendSubscribe** - Sends a task and subscribes to updates via streaming (SSE).
   - Parameters: same as `tasks/send`
   - Returns: SSE event stream with status and artifact updates
3. **tasks/get** - Gets the current status of a task.
   - Parameters: `id`, `historyLength` (optional)
   - Returns: current task status, artifacts, and history
4. **tasks/cancel** - Attempts to cancel a running task.
   - Parameters: `id`
   - Returns: updated task status
5. **tasks/pushNotification/set** - Configures push notifications for a task.
   - Parameters: `id`, `pushNotificationConfig` (URL and authentication)
   - Returns: updated notification configuration
6. **tasks/pushNotification/get** - Gets a task's push notification configuration.
   - Parameters: `id`
   - Returns: current notification configuration
7. **tasks/resubscribe** - Resubscribes to events of an existing task.
   - Parameters: `id`
   - Returns: SSE event stream
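For concreteness, a `tasks/send` exchange looks roughly like this. The field names follow the method list above; the literal values are made up for illustration:

```python
# Request body a client would POST to /a2a/{agent_id}
request = {
    "jsonrpc": "2.0",
    "id": "req-1",
    "method": "tasks/send",
    "params": {
        "id": "task-123",
        "sessionId": "session-1",
        "message": {
            "role": "user",
            "parts": [{"type": "text", "text": "Hello"}],
        },
    },
}

# Shape of a successful response
response = {
    "jsonrpc": "2.0",
    "id": "req-1",
    "result": {
        "id": "task-123",
        "status": {"state": "completed"},
        "artifacts": [],
        "history": [],
    },
}
```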
## Streaming and Push Notifications

### Streaming (SSE)

Streaming is implemented with Server-Sent Events (SSE) to deliver real-time updates to clients.

#### Integration with the existing streaming_service.py

The current `StreamingService` already implements SSE streaming. We will extend and adapt it:

```python
# Example integration with the existing streaming_service.py
async def send_task_streaming(request: SendTaskStreamingRequest) -> AsyncIterable[SendTaskStreamingResponse]:
    stream_service = StreamingService()
    async for event in stream_service.send_task_streaming(
        agent_id=request.params.metadata.get("agent_id"),
        api_key=api_key,
        message=request.params.message.parts[0].text,
        session_id=request.params.sessionId,
        db=db,
    ):
        # Convert the SSE event format to the A2A event format
        yield SendTaskStreamingResponse(
            id=request.id,
            result=convert_to_a2a_event_format(event),
        )
```
### Push Notifications

The push notification system lets the agent send updates to callback URLs configured by clients.

#### Integration with the existing push_notification_service.py

The current `PushNotificationService` already implements notification delivery. We will adapt it:

```python
# Example integration with the existing push_notification_service.py
async def send_push_notification(task_id, state, message=None):
    notification_config = await get_push_notification_config(task_id)
    if notification_config:
        await push_notification_service.send_notification(
            url=notification_config.url,
            task_id=task_id,
            state=state,
            message=message,
            headers=notification_config.headers,
        )
```

#### Push Notification Authentication

We will implement secure, JWT-based authentication for push notifications using `PushNotificationAuth`:

```python
# Example of wiring authentication into notifications
push_auth = PushNotificationSenderAuth()
push_auth.generate_jwk()

# Route exposing the public keys
@router.get("/{agent_id}/.well-known/jwks.json")
async def get_jwks(agent_id: str):
    return push_auth.handle_jwks_endpoint(request)

# Send notifications with authentication
async def send_authenticated_push_notification(url, data):
    await push_auth.send_push_notification(url, data)
```
## Data Storage Strategy

### Redis for Temporary Data

We will use Redis to store and manage the temporary data of A2A tasks, replacing the in-memory cache of the original sample:

```python
from src.services.redis_cache_service import RedisCacheService


class RedisCache:
    def __init__(self, redis_service: RedisCacheService):
        self.redis = redis_service

    # Task management
    async def get_task(self, task_id: str) -> dict:
        """Fetch a task by ID."""
        return self.redis.get(f"task:{task_id}")

    async def save_task(self, task_id: str, task_data: dict, ttl: int = 3600) -> None:
        """Save a task with a configurable TTL."""
        self.redis.set(f"task:{task_id}", task_data, ttl=ttl)

    async def update_task_status(self, task_id: str, status: dict) -> bool:
        """Update a task's status."""
        task_data = await self.get_task(task_id)
        if not task_data:
            return False
        task_data["status"] = status
        await self.save_task(task_id, task_data)
        return True

    # Task history
    async def append_to_history(self, task_id: str, message: dict) -> bool:
        """Append a message to the task history."""
        task_data = await self.get_task(task_id)
        if not task_data:
            return False
        if "history" not in task_data:
            task_data["history"] = []
        task_data["history"].append(message)
        await self.save_task(task_id, task_data)
        return True

    # Push notifications
    async def save_push_notification_config(self, task_id: str, config: dict) -> None:
        """Save a task's push notification configuration."""
        self.redis.set(f"push_notification:{task_id}", config, ttl=3600)

    async def get_push_notification_config(self, task_id: str) -> dict:
        """Fetch a task's push notification configuration."""
        return self.redis.get(f"push_notification:{task_id}")

    # SSE (Server-Sent Events)
    async def save_sse_client(self, task_id: str, client_id: str) -> None:
        """Register an SSE client for a task."""
        self.redis.set_hash(f"sse_clients:{task_id}", client_id, "active")

    async def get_sse_clients(self, task_id: str) -> list:
        """Fetch all SSE clients registered for a task."""
        return self.redis.get_all_hash(f"sse_clients:{task_id}")

    async def remove_sse_client(self, task_id: str, client_id: str) -> None:
        """Remove an SSE client from the registry."""
        self.redis.delete_hash(f"sse_clients:{task_id}", client_id)
```

The Redis service uses TTLs (time-to-live) to guarantee automatic cleanup of temporary data:

```python
# TTLs for the different data types
TASK_TTL = 3600               # 1 hour for tasks
HISTORY_TTL = 86400           # 24 hours for history
PUSH_NOTIFICATION_TTL = 3600  # 1 hour for notification configurations
SSE_CLIENT_TTL = 300          # 5 minutes for SSE clients
```
### Existing Models and Redis

The system keeps using the existing SQLAlchemy models for permanent data:

- **Agent**: agent data and configuration
- **Session**: persistent sessions
- **MCPServer**: tool-server configurations

For temporary data (A2A tasks, history, streaming) we will use Redis, which offers:

1. **Performance**: in-memory operations with optional persistence
2. **TTL**: automatic expiration of temporary data
3. **Data structures**: strings, hashes, and lists for the different needs
4. **Pub/Sub**: a mechanism for real-time notifications
5. **Scalability**: better multi-instance support than an in-memory cache

### TaskManager Implementation with Redis

`A2ATaskManager` implements the same interface as the sample `TaskManager`, but on top of Redis:

```python
class A2ATaskManager:
    """
    A2A task manager backed by Redis.

    Implements the A2A protocol interface for managing the task lifecycle.
    """

    def __init__(
        self,
        redis_cache: RedisCacheService,
        session_service=None,
        artifacts_service=None,
        memory_service=None,
        push_notification_service=None,
    ):
        self.redis_cache = redis_cache
        self.session_service = session_service
        self.artifacts_service = artifacts_service
        self.memory_service = memory_service
        self.push_notification_service = push_notification_service
        self.lock = asyncio.Lock()
        self.subscriber_lock = asyncio.Lock()
        self.task_sse_subscribers = {}

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """
        Get the current status of a task.

        Args:
            request: JSON-RPC request for the task data

        Returns:
            Response with the task data, or an error
        """
        logger.info(f"Getting task {request.params.id}")
        task_query_params = request.params
        task_data = self.redis_cache.get(f"task:{task_query_params.id}")
        if not task_data:
            return GetTaskResponse(id=request.id, error=TaskNotFoundError())
        # Trim the history as requested
        if task_query_params.historyLength and task_data.get("history"):
            task_data["history"] = task_data["history"][-task_query_params.historyLength:]
        return GetTaskResponse(id=request.id, result=task_data)
```
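The `historyLength` handling in `on_get_task` relies on Python's negative slicing, with a falsy value (None or 0) leaving the history untouched. A standalone check of that behavior (`trim_history` is an illustrative helper, not project code):

```python
def trim_history(task_data: dict, history_length) -> dict:
    """Keep only the most recent `history_length` entries.
    A falsy history_length (None or 0) leaves the history untouched."""
    if history_length and task_data.get("history"):
        task_data["history"] = task_data["history"][-history_length:]
    return task_data
```

Note the guard matters: without it, `history[-0:]` would return the whole list rather than zero entries, which is why a `historyLength` of 0 is treated as "no trimming".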
## A2A Server Service Implementation

The `A2AServer` service processes JSON-RPC requests according to the A2A protocol:

```python
class A2AServer:
    """
    A2A server implementing the JSON-RPC protocol for agent task processing.
    """

    def __init__(
        self,
        endpoint: str = "/",
        agent_card=None,
        task_manager=None,
        streaming_service=None,
    ):
        self.endpoint = endpoint
        self.agent_card = agent_card
        self.task_manager = task_manager
        self.streaming_service = streaming_service

    async def _process_request(self, request: Request):
        """
        Process an A2A protocol JSON-RPC request.

        Args:
            request: HTTP request

        Returns:
            JSON-RPC response or event stream
        """
        try:
            body = await request.json()
            json_rpc_request = A2ARequest.validate_python(body)
            # Delegate to the appropriate handler based on the request type
            if isinstance(json_rpc_request, GetTaskRequest):
                result = await self.task_manager.on_get_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskRequest):
                result = await self.task_manager.on_send_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskStreamingRequest):
                result = await self.task_manager.on_send_task_subscribe(json_rpc_request)
            elif isinstance(json_rpc_request, CancelTaskRequest):
                result = await self.task_manager.on_cancel_task(json_rpc_request)
            elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
                result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
                result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, TaskResubscriptionRequest):
                result = await self.task_manager.on_resubscribe_to_task(json_rpc_request)
            else:
                logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
                raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")
            return self._create_response(result)
        except Exception as e:
            return self._handle_exception(e)
```
## Push Notification Authentication Implementation

We will implement JWT authentication for push notifications, following the sample:

```python
class PushNotificationAuth:
    def __init__(self):
        self.public_keys = []
        self.private_key_jwk = None

    def generate_jwk(self):
        key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
        self.public_keys.append(key.export_public(as_dict=True))
        self.private_key_jwk = PyJWK.from_json(key.export_private())

    def handle_jwks_endpoint(self, request: Request):
        """Return the public keys to clients."""
        return JSONResponse({"keys": self.public_keys})

    async def send_authenticated_push_notification(self, url: str, data: dict):
        """Send a JWT-signed push notification."""
        jwt_token = self._generate_jwt(data)
        headers = {"Authorization": f"Bearer {jwt_token}"}
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                response = await client.post(url, json=data, headers=headers)
                response.raise_for_status()
                logger.info(f"Push-notification sent to URL: {url}")
            except Exception as e:
                logger.warning(f"Error sending push-notification to URL {url}: {e}")
```
## A2A Routes Review

The A2A routes will live in `src/api/a2a_routes.py`, reusing the existing Agent Card logic:

```python
@router.post("/{agent_id}")
async def process_a2a_request(
    agent_id: str,
    request: Request,
    x_api_key: str = Header(None, alias="x-api-key"),
    db: Session = Depends(get_db),
):
    """
    Endpoint that processes A2A protocol JSON-RPC requests.
    """
    # Validate the agent and the API key
    agent = get_agent(db, agent_id)
    if not agent:
        return JSONResponse(
            status_code=404,
            content={"detail": "Agent not found"},
        )
    if agent.config.get("api_key") != x_api_key:
        return JSONResponse(
            status_code=401,
            content={"detail": "Invalid API key"},
        )
    # Build the Agent Card for this agent (reusing existing logic)
    agent_card = create_agent_card_from_agent(agent, db)
    # Configure the A2A server for this agent
    a2a_server.agent_card = agent_card
    # Process the A2A request
    return await a2a_server._process_request(request)
```
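One note on the API-key check above: a plain `!=` comparison can in principle leak timing information. A stdlib-only sketch of a hardened check (`verify_api_key` is an illustrative helper, not project code):

```python
import hmac


def verify_api_key(expected: str, provided) -> bool:
    """Constant-time API key comparison; a missing header is invalid."""
    if not expected or not provided:
        return False
    return hmac.compare_digest(expected.encode(), str(provided).encode())
```

`hmac.compare_digest` always scans the full input, so the comparison time does not reveal how many leading characters matched.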
## Files to Create/Update

### New Files

1. `src/schemas/a2a.py` - Pydantic models for the A2A protocol
2. `src/services/redis_cache_service.py` - Redis cache service
3. `src/config/redis.py` - Redis client configuration
4. `src/utils/a2a_utils.py` - A2A protocol utilities
5. `src/services/a2a_task_manager_service.py` - A2A task manager
6. `src/services/a2a_server_service.py` - A2A server
7. `src/services/push_notification_auth_service.py` - Push notification authentication
8. `src/api/a2a_routes.py` - A2A protocol routes

### Files to Update

1. `src/main.py` - Register the new A2A routes
2. `pyproject.toml` - Add dependencies (Redis, jwcrypto, etc.)

## Implementation Plan

### Phase 1: Schemas

1. Create `src/schemas/a2a.py` with Pydantic models based on `common/types.py`
2. Adapt the types to the project structure and add support for streaming and push notifications

### Phase 2: Redis Cache Service

1. Create `src/config/redis.py` for the Redis client configuration
2. Create `src/services/redis_cache_service.py` for cache management

### Phase 3: Utilities

1. Create `src/utils/a2a_utils.py` with utility functions based on `common/server/utils.py`
2. Adapt `PushNotificationAuth` for use in the A2A context

### Phase 4: Task Manager

1. Create `src/services/a2a_task_manager_service.py` with the `A2ATaskManager` implementation
2. Integrate with existing services:
   - agent_runner.py for agent execution
   - streaming_service.py for SSE streaming
   - push_notification_service.py for push notifications
   - redis_cache_service.py for task caching

### Phase 5: A2A Server

1. Create `src/services/a2a_server_service.py` with the `A2AServer` implementation
2. Implement JSON-RPC request processing for all A2A operations

### Phase 6: Integration

1. Create `src/api/a2a_routes.py` with the A2A protocol routes
2. Register the routes in the main FastAPI application
3. Ensure all A2A operations work end to end, including streaming and push notifications

## Required Adaptations

1. **Schemas**: adapt the A2A protocol models to reuse existing Pydantic schemas where possible
2. **Authentication**: integrate with the existing API-key authentication
3. **Streaming**: adapt the existing `StreamingService` to the A2A event format
4. **Push Notifications**: integrate the existing `PushNotificationService` and add JWT authentication
5. **Cache**: use Redis for temporary storage of tasks and events
6. **Agent Execution**: reuse the existing `agent_runner.py` service

## Next Steps

1. Configure the Redis dependencies in the project
2. Implement the schemas in `src/schemas/a2a.py`
3. Implement the Redis cache service
4. Implement the utility functions in `src/utils/a2a_utils.py`
5. Implement the task manager in `src/services/a2a_task_manager_service.py`
6. Implement the A2A server in `src/services/a2a_server_service.py`
7. Implement the routes in `src/api/a2a_routes.py`
8. Register the routes in the main application `src/main.py`
9. Test the implementation with complete use cases, including streaming and push notifications


@@ -10,14 +10,21 @@ services:
ports:
- "8000:8000"
environment:
- POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@postgres:5432/evo_ai
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
- SENDGRID_API_KEY=${SENDGRID_API_KEY}
- EMAIL_FROM=${EMAIL_FROM}
- APP_URL=${APP_URL}
POSTGRES_CONNECTION_STRING: postgresql://postgres:postgres@postgres:5432/evo_ai
REDIS_HOST: redis
REDIS_PORT: 6379
REDIS_PASSWORD: ""
REDIS_SSL: "false"
REDIS_KEY_PREFIX: "a2a:"
REDIS_TTL: 3600
A2A_TASK_TTL: 3600
A2A_HISTORY_TTL: 86400
A2A_PUSH_NOTIFICATION_TTL: 3600
A2A_SSE_CLIENT_TTL: 300
JWT_SECRET_KEY: ${JWT_SECRET_KEY}
SENDGRID_API_KEY: ${SENDGRID_API_KEY}
EMAIL_FROM: ${EMAIL_FROM}
APP_URL: ${APP_URL}
volumes:
- ./logs:/app/logs
restart: unless-stopped
@@ -26,9 +33,9 @@ services:
image: postgres:14-alpine
container_name: evo-ai-postgres
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
- POSTGRES_DB=evo_ai
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: evo_ai
ports:
- "5432:5432"
volumes:
@@ -42,6 +49,12 @@ services:
- "6379:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 30s
retries: 50
restart: unless-stopped
volumes:


@@ -43,6 +43,10 @@ dependencies = [
"pydantic[email]==2.11.3",
"httpx==0.28.1",
"httpx-sse==0.4.0",
"redis==5.3.0",
"sse-starlette==2.3.3",
"jwcrypto==1.5.6",
"pyjwt[crypto]==2.9.0",
]
[project.optional-dependencies]

src/api/a2a_routes.py Normal file

@@ -0,0 +1,395 @@
"""
Routes for the A2A (Agent-to-Agent) protocol.
This module implements the standard A2A routes according to the specification.
"""
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse
from src.config.database import get_db
from src.services import agent_service
from src.services import (
RedisCacheService,
AgentRunnerAdapter,
StreamingServiceAdapter,
create_agent_card_from_agent,
)
from src.services.a2a_task_manager_service import A2ATaskManager
from src.services.a2a_server_service import A2AServer
from src.services.agent_runner import run_agent
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from src.services.push_notification_service import push_notification_service
from src.services.push_notification_auth_service import push_notification_auth
from src.services.streaming_service import StreamingService
logger = logging.getLogger(__name__)
# Create router with prefix /a2a according to the standard protocol
router = APIRouter(
prefix="/a2a",
tags=["a2a"],
responses={
404: {"description": "Not found"},
400: {"description": "Bad request"},
401: {"description": "Unauthorized"},
500: {"description": "Internal server error"},
},
)
# Singleton instances for shared resources
streaming_service = StreamingService()
redis_cache_service = RedisCacheService()
streaming_adapter = StreamingServiceAdapter(streaming_service)
# Cache dictionary to keep A2ATaskManager instances per agent
# This avoids creating new instances on every request
_task_manager_cache = {}
_agent_runner_cache = {}
def get_agent_runner_adapter(db=None, reuse=True, agent_id=None):
"""
Get or create an agent runner adapter.
Args:
db: Database session
reuse: Whether to reuse an existing instance
agent_id: Agent ID to use as cache key
Returns:
Agent runner adapter instance
"""
cache_key = str(agent_id) if agent_id else "default"
logger.info(
f"[DEBUG] get_agent_runner_adapter chamado para agent_id={agent_id}, reuse={reuse}, cache_key={cache_key}"
)
if reuse and cache_key in _agent_runner_cache:
adapter = _agent_runner_cache[cache_key]
logger.info(
f"[DEBUG] Reutilizando agent_runner_adapter existente para {cache_key}"
)
# Update the DB session if provided
if db is not None:
adapter.db = db
return adapter
logger.info(
f"[IMPORTANTE] Criando NOVA instância de AgentRunnerAdapter para {cache_key}"
)
adapter = AgentRunnerAdapter(
agent_runner_func=run_agent,
session_service=session_service,
artifacts_service=artifacts_service,
memory_service=memory_service,
db=db,
)
if reuse:
logger.info(f"[DEBUG] Armazenando nova instância no cache para {cache_key}")
_agent_runner_cache[cache_key] = adapter
return adapter
def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"):
cache_key = str(agent_id)
# For query operations, NEVER create an agent_runner
if operation_type == "query":
if cache_key in _task_manager_cache:
# Reuse the existing instance
task_manager = _task_manager_cache[cache_key]
task_manager.db = db
return task_manager
# If none exists, create a task_manager WITHOUT an agent_runner for queries
return A2ATaskManager(
redis_cache=redis_cache_service,
agent_runner=None,  # No agent_runner for queries!
streaming_service=streaming_adapter,
push_notification_service=push_notification_service,
db=db,
)
# For execution operations, use the normal flow
if reuse and cache_key in _task_manager_cache:
# Update the db session
task_manager = _task_manager_cache[cache_key]
task_manager.db = db
return task_manager
# Create new
agent_runner_adapter = get_agent_runner_adapter(
db=db, reuse=reuse, agent_id=agent_id
)
task_manager = A2ATaskManager(
redis_cache=redis_cache_service,
agent_runner=agent_runner_adapter,
streaming_service=streaming_adapter,
push_notification_service=push_notification_service,
db=db,
)
_task_manager_cache[cache_key] = task_manager
return task_manager
@router.post("/{agent_id}")
async def process_a2a_request(
agent_id: uuid.UUID,
request: Request,
x_api_key: str = Header(None, alias="x-api-key"),
db: Session = Depends(get_db),
):
"""
Main endpoint for processing JSON-RPC requests of the A2A protocol.
This endpoint processes all JSON-RPC methods of the A2A protocol, including:
- tasks/send: Sending tasks
- tasks/sendSubscribe: Sending tasks with streaming
- tasks/get: Querying task status
- tasks/cancel: Cancelling tasks
- tasks/pushNotification/set: Setting push notifications
- tasks/pushNotification/get: Querying push notification configurations
- tasks/resubscribe: Resubscribing to receive task updates
Args:
agent_id: Agent ID
request: HTTP request with JSON-RPC payload
x_api_key: API key for authentication
db: Database session
Returns:
JSON-RPC response or streaming (SSE) depending on the method
"""
try:
# Detailed request log
logger.info(f"Request received for A2A agent {agent_id}")
logger.info(f"Headers: {dict(request.headers)}")
try:
body = await request.json()
method = body.get("method", "unknown")
logger.info(f"[IMPORTANTE] Método solicitado: {method}")
logger.info(f"Request body: {body}")
# Determine whether this is a query request (get_task) or an execution request (send_task)
is_query_request = method in [
"tasks/get",
"tasks/cancel",
"tasks/pushNotification/get",
"tasks/resubscribe",
]
# For queries we reuse components; for executions,
# we create new ones to guarantee a clean state
reuse_components = is_query_request
logger.info(
f"[IMPORTANTE] Is query request: {is_query_request}, Reuse components: {reuse_components}"
)
except Exception as e:
logger.error(f"Error reading request body: {e}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32700,
"message": f"Parse error: {str(e)}",
"data": None,
},
},
)
# Verify if the agent exists
agent = agent_service.get_agent(db, agent_id)
if agent is None:
logger.warning(f"Agent not found: {agent_id}")
return JSONResponse(
status_code=404,
content={
"jsonrpc": "2.0",
"id": None,
"error": {"code": 404, "message": "Agent not found", "data": None},
},
)
# Verify API key
agent_config = agent.config
logger.info(f"Received API Key: {x_api_key}")
logger.info(f"Expected API Key: {agent_config.get('api_key')}")
if x_api_key and agent_config.get("api_key") != x_api_key:
logger.warning(f"Invalid API Key for agent {agent_id}")
return JSONResponse(
status_code=401,
content={
"jsonrpc": "2.0",
"id": None,
"error": {"code": 401, "message": "Invalid API key", "data": None},
},
)
# Get the task manager for this agent (reusing when possible)
a2a_task_manager = get_task_manager(
agent_id,
db=db,
reuse=reuse_components,
operation_type="query" if is_query_request else "execution",
)
a2a_server = A2AServer(task_manager=a2a_task_manager)
# Configure agent_card for the A2A server
logger.info("Configuring agent_card for A2A server")
agent_card = create_agent_card_from_agent(agent, db)
a2a_server.agent_card = agent_card
# Verify JSON-RPC format
if not body.get("jsonrpc") or body.get("jsonrpc") != "2.0":
logger.error(f"Invalid JSON-RPC format: {body.get('jsonrpc')}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id"),
"error": {
"code": -32600,
"message": "Invalid Request: jsonrpc must be '2.0'",
"data": None,
},
},
)
# Verify the method
if not body.get("method"):
logger.error("Method not specified in request")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id"),
"error": {
"code": -32600,
"message": "Invalid Request: method is required",
"data": None,
},
},
)
logger.info(f"Processing method: {body.get('method')}")
# Process the request with the A2A server
logger.info("Sending request to A2A server")
# Pass the agent_id and db directly to the process_request method
return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
except Exception as e:
logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32603,
"message": "Internal server error",
"data": {"detail": str(e)},
},
},
)
@router.get("/{agent_id}/.well-known/agent.json")
async def get_agent_card(
agent_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
):
"""
Endpoint to get the Agent Card in the .well-known format of the A2A protocol.
This endpoint returns the agent information in the standard A2A format,
including capabilities, authentication information, and skills.
Args:
agent_id: Agent ID
request: HTTP request
db: Database session
Returns:
Agent Card in JSON format
"""
try:
agent = agent_service.get_agent(db, agent_id)
if agent is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
)
agent_card = create_agent_card_from_agent(agent, db)
# Get the task manager for this agent (reusing when possible)
a2a_task_manager = get_task_manager(agent_id, db=db, reuse=True)
a2a_server = A2AServer(task_manager=a2a_task_manager)
# Configure the A2A server with the agent card
a2a_server.agent_card = agent_card
# Use the A2A server to deliver the agent card, ensuring protocol compatibility
return await a2a_server.get_agent_card(request, db=db)
except Exception as e:
logger.error(f"Error generating agent card: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error generating agent card",
)
@router.get("/{agent_id}/.well-known/jwks.json")
async def get_jwks(
agent_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
):
"""
Endpoint to get the public JWKS keys for verifying the authenticity
of push notifications.
Clients can use these keys to verify the authenticity of received notifications.
Args:
agent_id: Agent ID
request: HTTP request
db: Database session
Returns:
JSON with the public keys in JWKS format
"""
try:
# Verify if the agent exists
agent = agent_service.get_agent(db, agent_id)
if agent is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
)
# Return the public keys
return push_notification_auth.handle_jwks_endpoint(request)
except Exception as e:
logger.error(f"Error obtaining JWKS: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error obtaining JWKS",
)


@@ -1,6 +1,3 @@
from datetime import datetime
import asyncio
import os
from fastapi import APIRouter, Depends, HTTPException, status, Header
from sqlalchemy.orm import Session
from src.config.database import get_db
@@ -18,19 +15,7 @@ from src.services import (
agent_service,
mcp_server_service,
)
from src.services.agent_runner import run_agent
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from src.services.push_notification_service import push_notification_service
import logging
from fastapi.responses import StreamingResponse
from ..services.streaming_service import StreamingService
from ..schemas.streaming import JSONRPCRequest
from src.services.session_service import get_session_events
logger = logging.getLogger(__name__)
@@ -79,8 +64,6 @@ router = APIRouter(
responses={404: {"description": "Not found"}},
)
streaming_service = StreamingService()
@router.post("/", response_model=Agent, status_code=status.HTTP_201_CREATED)
async def create_agent(
@@ -172,348 +155,3 @@ async def delete_agent(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
)
@router.get("/{agent_id}/.well-known/agent.json")
async def get_agent_json(
agent_id: uuid.UUID,
db: Session = Depends(get_db),
):
try:
agent = agent_service.get_agent(db, agent_id)
if agent is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
)
mcp_servers = agent.config.get("mcp_servers", [])
formatted_tools = await format_agent_tools(mcp_servers, db)
AGENT_CARD = {
"name": agent.name,
"description": agent.description,
"url": f"{os.getenv('API_URL', '')}/api/v1/agents/{agent.id}",
"provider": {
"organization": os.getenv("ORGANIZATION_NAME", ""),
"url": os.getenv("ORGANIZATION_URL", ""),
},
"version": os.getenv("API_VERSION", ""),
"capabilities": {
"streaming": True,
"pushNotifications": True,
"stateTransitionHistory": True,
},
"authentication": {
"schemes": ["apiKey"],
"credentials": {"in": "header", "name": "x-api-key"},
},
"defaultInputModes": ["text", "application/json"],
"defaultOutputModes": ["text", "application/json"],
"skills": formatted_tools,
}
return AGENT_CARD
except Exception as e:
logger.error(f"Error generating agent card: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error generating agent card",
)
@router.post("/{agent_id}/tasks/send")
async def handle_task(
agent_id: uuid.UUID,
request: JSONRPCRequest,
x_api_key: str = Header(..., alias="x-api-key"),
db: Session = Depends(get_db),
):
"""Endpoint to clients A2A send a new task (with an initial user message)."""
try:
# Verify agent
agent = agent_service.get_agent(db, agent_id)
if agent is None:
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {"code": 404, "message": "Agent not found", "data": None},
}
# Verify API key
agent_config = agent.config
if agent_config.get("api_key") != x_api_key:
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {"code": 401, "message": "Invalid API key", "data": None},
}
# Extract task request from JSON-RPC params
task_request = request.params
# Validate required fields
task_id = task_request.get("id")
if not task_id:
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {
"code": -32602,
"message": "Invalid parameters",
"data": {"detail": "Task ID is required"},
},
}
# Extract user message
try:
user_message = task_request["message"]["parts"][0]["text"]
except (KeyError, IndexError):
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {
"code": -32602,
"message": "Invalid parameters",
"data": {"detail": "Invalid message format"},
},
}
# Configure session and metadata
session_id = f"{task_id}_{agent_id}"
metadata = task_request.get("metadata", {})
history_length = metadata.get("historyLength", 50)
# Initialize response
response_task = {
"id": task_id,
"sessionId": session_id,
"status": {
"state": "submitted",
"timestamp": datetime.now().isoformat(),
"message": None,
"error": None,
},
"artifacts": [],
"history": [],
"metadata": metadata,
}
# Handle push notification configuration
push_notification = task_request.get("pushNotification")
if push_notification:
url = push_notification.get("url")
headers = push_notification.get("headers", {})
if not url:
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {
"code": -32602,
"message": "Invalid parameters",
"data": {"detail": "Push notification URL is required"},
},
}
# Store push notification config in metadata
response_task["metadata"]["pushNotification"] = {
"url": url,
"headers": headers,
}
# Send initial notification
asyncio.create_task(
push_notification_service.send_notification(
url=url, task_id=task_id, state="submitted", headers=headers
)
)
try:
# Update status to running
response_task["status"].update(
{"state": "running", "timestamp": datetime.now().isoformat()}
)
# Send running notification if configured
if push_notification:
asyncio.create_task(
push_notification_service.send_notification(
url=url, task_id=task_id, state="running", headers=headers
)
)
# Execute agent
final_response_text = await run_agent(
str(agent_id),
task_id,
user_message,
session_service,
artifacts_service,
memory_service,
db,
session_id,
)
# Update status to completed
response_task["status"].update(
{
"state": "completed",
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": final_response_text}],
},
}
)
# Add artifacts
if final_response_text:
response_task["artifacts"].append(
{
"type": "text",
"content": final_response_text,
"metadata": {
"generated_at": datetime.now().isoformat(),
"content_type": "text/plain",
},
}
)
# Send completed notification if configured
if push_notification:
asyncio.create_task(
push_notification_service.send_notification(
url=url,
task_id=task_id,
state="completed",
message={
"role": "agent",
"parts": [{"type": "text", "text": final_response_text}],
},
headers=headers,
)
)
except Exception as e:
# Update status to failed
response_task["status"].update(
{
"state": "failed",
"timestamp": datetime.now().isoformat(),
"error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)},
}
)
# Send failed notification if configured
if push_notification:
asyncio.create_task(
push_notification_service.send_notification(
url=url,
task_id=task_id,
state="failed",
message={
"role": "system",
"parts": [{"type": "text", "text": str(e)}],
},
headers=headers,
)
)
# Process history
try:
history_messages = get_session_events(session_service, session_id)
history_messages = history_messages[-history_length:]
formatted_history = []
for event in history_messages:
if event.content and event.content.parts:
role = (
"agent" if event.content.role == "model" else event.content.role
)
formatted_history.append(
{
"role": role,
"parts": [
{"type": "text", "text": part.text}
for part in event.content.parts
if part.text
],
}
)
response_task["history"] = formatted_history
except Exception as e:
logger.error(f"Error processing history: {str(e)}")
# Return JSON-RPC response
return {"jsonrpc": "2.0", "id": request.id, "result": response_task}
except HTTPException as e:
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {"code": e.status_code, "message": e.detail, "data": None},
}
except Exception as e:
logger.error(f"Unexpected error in handle_task: {str(e)}")
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {
"code": -32603,
"message": "Internal server error",
"data": {"detail": str(e)},
},
}
@router.post("/{agent_id}/tasks/sendSubscribe")
async def subscribe_task_streaming(
agent_id: str,
request: JSONRPCRequest,
x_api_key: str = Header(None),
db: Session = Depends(get_db),
):
"""
Endpoint for streaming SSE events of a task.
Args:
agent_id: Agent ID
request: JSON-RPC request
x_api_key: API key in the header
db: Database session
Returns:
StreamingResponse with SSE events
"""
if not x_api_key:
return {
"jsonrpc": "2.0",
"id": request.id,
"error": {"code": 401, "message": "API key é obrigatória", "data": None},
}
# Extract the message from the payload
message = request.params.get("message", {}).get("parts", [{}])[0].get("text", "")
session_id = request.params.get("sessionId")
# Configure streaming
async def event_generator():
async for event in streaming_service.send_task_streaming(
agent_id=agent_id,
api_key=x_api_key,
message=message,
session_id=session_id,
db=db,
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

src/config/redis.py Normal file

@@ -0,0 +1,85 @@
"""
Redis configuration module.
This module defines the Redis connection settings and provides
a function to create a Redis connection pool for the application.
"""
import os
import redis
from dotenv import load_dotenv
import logging
# Load environment variables
load_dotenv()
logger = logging.getLogger(__name__)
def get_redis_config():
"""
Get Redis configuration from environment variables.
Returns:
dict: Redis configuration parameters
"""
return {
"host": os.getenv("REDIS_HOST", "localhost"),
"port": int(os.getenv("REDIS_PORT", 6379)),
"db": int(os.getenv("REDIS_DB", 0)),
"password": os.getenv("REDIS_PASSWORD", None),
"ssl": os.getenv("REDIS_SSL", "false").lower() == "true",
"key_prefix": os.getenv("REDIS_KEY_PREFIX", "a2a:"),
"default_ttl": int(os.getenv("REDIS_TTL", 3600)),
}
def get_a2a_config():
"""
Get A2A-specific cache TTL values from environment variables.
Returns:
dict: A2A TTL configuration parameters
"""
return {
"task_ttl": int(os.getenv("A2A_TASK_TTL", 3600)),
"history_ttl": int(os.getenv("A2A_HISTORY_TTL", 86400)),
"push_notification_ttl": int(os.getenv("A2A_PUSH_NOTIFICATION_TTL", 3600)),
"sse_client_ttl": int(os.getenv("A2A_SSE_CLIENT_TTL", 300)),
}
def create_redis_pool(config=None):
"""
Create and return a Redis connection pool.
Args:
config (dict, optional): Redis configuration. If None,
configuration is loaded from environment
Returns:
redis.ConnectionPool: Redis connection pool
"""
if config is None:
config = get_redis_config()
try:
connection_pool = redis.ConnectionPool(
host=config["host"],
port=config["port"],
db=config["db"],
password=config["password"] if config["password"] else None,
ssl=config["ssl"],
decode_responses=True,
)
# Test the connection
redis_client = redis.Redis(connection_pool=connection_pool)
redis_client.ping()
logger.info(
f"Redis connection successful: {config['host']}:{config['port']}, "
f"db={config['db']}, ssl={config['ssl']}"
)
return connection_pool
except redis.RedisError as e:
logger.error(f"Redis connection error: {e}")
raise


@@ -37,6 +37,9 @@ class Settings(BaseSettings):
REDIS_PORT: int = int(os.getenv("REDIS_PORT", 6379))
REDIS_DB: int = int(os.getenv("REDIS_DB", 0))
REDIS_PASSWORD: Optional[str] = os.getenv("REDIS_PASSWORD")
REDIS_SSL: bool = os.getenv("REDIS_SSL", "false").lower() == "true"
REDIS_KEY_PREFIX: str = os.getenv("REDIS_KEY_PREFIX", "evoai:")
REDIS_TTL: int = int(os.getenv("REDIS_TTL", 3600))
# Tool cache TTL in seconds (1 hour)
TOOLS_CACHE_TTL: int = int(os.getenv("TOOLS_CACHE_TTL", 3600))


@@ -22,6 +22,7 @@ import src.api.contact_routes
import src.api.mcp_server_routes
import src.api.tool_routes
import src.api.client_routes
import src.api.a2a_routes
# Add the root directory to PYTHONPATH
root_dir = Path(__file__).parent.parent
@@ -40,13 +41,13 @@ app = FastAPI(
# CORS configuration
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Permite todas as origens em desenvolvimento
allow_origins=["*"], # Allows all origins in development
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Configuração de arquivos estáticos
# Static files configuration
static_dir = Path("static")
if not static_dir.exists():
static_dir.mkdir(parents=True)
@@ -72,6 +73,7 @@ contact_router = src.api.contact_routes.router
mcp_server_router = src.api.mcp_server_routes.router
tool_router = src.api.tool_routes.router
client_router = src.api.client_routes.router
a2a_router = src.api.a2a_routes.router
# Include routes
app.include_router(auth_router, prefix=API_PREFIX)
@@ -83,6 +85,7 @@ app.include_router(chat_router, prefix=API_PREFIX)
app.include_router(session_router, prefix=API_PREFIX)
app.include_router(agent_router, prefix=API_PREFIX)
app.include_router(contact_router, prefix=API_PREFIX)
app.include_router(a2a_router, prefix=API_PREFIX)
@app.get("/")


@@ -0,0 +1,9 @@
"""
A2A (Agent-to-Agent) schema package.
This package contains Pydantic schema definitions for the A2A protocol.
"""
from src.schemas.a2a.types import *
from src.schemas.a2a.exceptions import *
from src.schemas.a2a.validators import *


@@ -0,0 +1,147 @@
"""
A2A (Agent-to-Agent) protocol exception definitions.
This module contains error types and exceptions for the A2A protocol.
"""
from src.schemas.a2a.types import JSONRPCError
class JSONParseError(JSONRPCError):
"""
Error raised when JSON parsing fails.
"""
code: int = -32700
message: str = "Invalid JSON payload"
data: object | None = None
class InvalidRequestError(JSONRPCError):
"""
Error raised when request validation fails.
"""
code: int = -32600
message: str = "Request payload validation error"
data: object | None = None
class MethodNotFoundError(JSONRPCError):
"""
Error raised when the requested method is not found.
"""
code: int = -32601
message: str = "Method not found"
data: None = None
class InvalidParamsError(JSONRPCError):
"""
Error raised when the parameters are invalid.
"""
code: int = -32602
message: str = "Invalid parameters"
data: object | None = None
class InternalError(JSONRPCError):
"""
Error raised when an internal error occurs.
"""
code: int = -32603
message: str = "Internal error"
data: object | None = None
class TaskNotFoundError(JSONRPCError):
"""
Error raised when the requested task is not found.
"""
code: int = -32001
message: str = "Task not found"
data: None = None
class TaskNotCancelableError(JSONRPCError):
"""
Error raised when a task cannot be canceled.
"""
code: int = -32002
message: str = "Task cannot be canceled"
data: None = None
class PushNotificationNotSupportedError(JSONRPCError):
"""
Error raised when push notifications are not supported.
"""
code: int = -32003
message: str = "Push Notification is not supported"
data: None = None
class UnsupportedOperationError(JSONRPCError):
"""
Error raised when an operation is not supported.
"""
code: int = -32004
message: str = "This operation is not supported"
data: None = None
class ContentTypeNotSupportedError(JSONRPCError):
"""
Error raised when content types are incompatible.
"""
code: int = -32005
message: str = "Incompatible content types"
data: None = None
# Client exceptions
class A2AClientError(Exception):
"""
Base exception for A2A client errors.
"""
pass
class A2AClientHTTPError(A2AClientError):
"""
Exception for HTTP errors in A2A client.
"""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f"HTTP Error {status_code}: {message}")
class A2AClientJSONError(A2AClientError):
"""
Exception for JSON errors in A2A client.
"""
def __init__(self, message: str):
self.message = message
super().__init__(f"JSON Error: {message}")
class MissingAPIKeyError(Exception):
"""
Exception for missing API key.
"""
pass

src/schemas/a2a/types.py Normal file

@@ -0,0 +1,464 @@
"""
A2A (Agent-to-Agent) protocol type definitions.
This module contains Pydantic schema definitions for the A2A protocol.
"""
from typing import Union, Any, List, Optional, Annotated, Literal
from pydantic import (
BaseModel,
Field,
TypeAdapter,
field_serializer,
model_validator,
ConfigDict,
)
from datetime import datetime
from uuid import uuid4
from enum import Enum
from typing_extensions import Self
class TaskState(str, Enum):
"""
Enum for the state of a task in the A2A protocol.
States follow the A2A protocol specification.
"""
SUBMITTED = "submitted"
WORKING = "working"
INPUT_REQUIRED = "input-required"
COMPLETED = "completed"
CANCELED = "canceled"
FAILED = "failed"
UNKNOWN = "unknown"
class TextPart(BaseModel):
"""
Represents a text part in a message.
"""
type: Literal["text"] = "text"
text: str
metadata: dict[str, Any] | None = None
class FileContent(BaseModel):
"""
Represents file content in a file part.
Either bytes or uri must be provided, but not both.
"""
name: str | None = None
mimeType: str | None = None
bytes: str | None = None
uri: str | None = None
@model_validator(mode="after")
def check_content(self) -> Self:
"""
Validates that either bytes or uri is present, but not both.
"""
if not (self.bytes or self.uri):
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
if self.bytes and self.uri:
raise ValueError(
"Only one of 'bytes' or 'uri' can be present in the file data"
)
return self
class FilePart(BaseModel):
"""
Represents a file part in a message.
"""
type: Literal["file"] = "file"
file: FileContent
metadata: dict[str, Any] | None = None
class DataPart(BaseModel):
"""
Represents a data part in a message.
"""
type: Literal["data"] = "data"
data: dict[str, Any]
metadata: dict[str, Any] | None = None
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
class Message(BaseModel):
"""
Represents a message in the A2A protocol.
A message consists of a role and one or more parts.
"""
role: Literal["user", "agent"]
parts: List[Part]
metadata: dict[str, Any] | None = None
class TaskStatus(BaseModel):
"""
Represents the status of a task.
"""
state: TaskState
message: Message | None = None
timestamp: datetime = Field(default_factory=datetime.now)
@field_serializer("timestamp")
def serialize_dt(self, dt: datetime, _info):
"""
Serializes datetime to ISO format.
"""
return dt.isoformat()
class Artifact(BaseModel):
"""
Represents an artifact produced by an agent.
"""
name: str | None = None
description: str | None = None
parts: List[Part]
metadata: dict[str, Any] | None = None
index: int = 0
append: bool | None = None
lastChunk: bool | None = None
class Task(BaseModel):
"""
Represents a task in the A2A protocol.
"""
id: str
sessionId: str | None = None
status: TaskStatus
artifacts: List[Artifact] | None = None
history: List[Message] | None = None
metadata: dict[str, Any] | None = None
class TaskStatusUpdateEvent(BaseModel):
"""
Represents a task status update event.
"""
id: str
status: TaskStatus
final: bool = False
metadata: dict[str, Any] | None = None
class TaskArtifactUpdateEvent(BaseModel):
"""
Represents a task artifact update event.
"""
id: str
artifact: Artifact
metadata: dict[str, Any] | None = None
class AuthenticationInfo(BaseModel):
"""
Represents authentication information for push notifications.
"""
model_config = ConfigDict(extra="allow")
schemes: List[str]
credentials: str | None = None
class PushNotificationConfig(BaseModel):
"""
Represents push notification configuration.
"""
url: str
token: str | None = None
authentication: AuthenticationInfo | None = None
class TaskIdParams(BaseModel):
"""
Represents parameters for identifying a task.
"""
id: str
metadata: dict[str, Any] | None = None
class TaskQueryParams(TaskIdParams):
"""
Represents parameters for querying a task.
"""
historyLength: int | None = None
class TaskSendParams(BaseModel):
"""
Represents parameters for sending a task.
"""
id: str
sessionId: str = Field(default_factory=lambda: uuid4().hex)
message: Message
acceptedOutputModes: Optional[List[str]] = None
pushNotification: PushNotificationConfig | None = None
historyLength: int | None = None
metadata: dict[str, Any] | None = None
agentId: str = ""
class TaskPushNotificationConfig(BaseModel):
"""
Represents push notification configuration for a task.
"""
id: str
pushNotificationConfig: PushNotificationConfig
# RPC Messages
class JSONRPCMessage(BaseModel):
"""
Base class for JSON-RPC messages.
"""
jsonrpc: Literal["2.0"] = "2.0"
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
class JSONRPCRequest(JSONRPCMessage):
"""
Represents a JSON-RPC request.
"""
method: str
params: dict[str, Any] | None = None
class JSONRPCError(BaseModel):
"""
Represents a JSON-RPC error.
"""
code: int
message: str
data: Any | None = None
class JSONRPCResponse(JSONRPCMessage):
"""
Represents a JSON-RPC response.
"""
result: Any | None = None
error: JSONRPCError | None = None
class SendTaskRequest(JSONRPCRequest):
"""
Represents a request to send a task.
"""
method: Literal["tasks/send"] = "tasks/send"
params: TaskSendParams
class SendTaskResponse(JSONRPCResponse):
"""
Represents a response to a send task request.
"""
result: Task | None = None
class SendTaskStreamingRequest(JSONRPCRequest):
"""
Represents a request to send a task with streaming.
"""
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
params: TaskSendParams
class SendTaskStreamingResponse(JSONRPCResponse):
"""
Represents a streaming response to a send task request.
"""
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
class GetTaskRequest(JSONRPCRequest):
"""
Represents a request to get task information.
"""
method: Literal["tasks/get"] = "tasks/get"
params: TaskQueryParams
class GetTaskResponse(JSONRPCResponse):
"""
Represents a response to a get task request.
"""
result: Task | None = None
class CancelTaskRequest(JSONRPCRequest):
"""
Represents a request to cancel a task.
"""
method: Literal["tasks/cancel"] = "tasks/cancel"
params: TaskIdParams
class CancelTaskResponse(JSONRPCResponse):
"""
Represents a response to a cancel task request.
"""
result: Task | None = None
class SetTaskPushNotificationRequest(JSONRPCRequest):
"""
Represents a request to set push notification for a task.
"""
method: Literal["tasks/pushNotification/set"] = "tasks/pushNotification/set"
params: TaskPushNotificationConfig
class SetTaskPushNotificationResponse(JSONRPCResponse):
"""
Represents a response to a set push notification request.
"""
result: TaskPushNotificationConfig | None = None
class GetTaskPushNotificationRequest(JSONRPCRequest):
"""
Represents a request to get push notification configuration for a task.
"""
method: Literal["tasks/pushNotification/get"] = "tasks/pushNotification/get"
params: TaskIdParams
class GetTaskPushNotificationResponse(JSONRPCResponse):
"""
Represents a response to a get push notification request.
"""
result: TaskPushNotificationConfig | None = None
class TaskResubscriptionRequest(JSONRPCRequest):
"""
Represents a request to resubscribe to a task.
"""
method: Literal["tasks/resubscribe"] = "tasks/resubscribe"
params: TaskIdParams
# TypeAdapter for discriminating A2A requests by method
A2ARequest = TypeAdapter(
Annotated[
Union[
SendTaskRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
],
Field(discriminator="method"),
]
)
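The `A2ARequest` adapter above dispatches validation on the `method` literal. A minimal, self-contained sketch of the same pydantic v2 discriminated-union pattern (the `PingRequest`/`EchoRequest` classes here are hypothetical stand-ins, not part of the A2A schema):

```python
from typing import Annotated, Any, Literal, Optional, Union
from uuid import uuid4

from pydantic import BaseModel, Field, TypeAdapter

class PingRequest(BaseModel):
    jsonrpc: Literal["2.0"] = "2.0"
    id: Optional[str] = Field(default_factory=lambda: uuid4().hex)
    method: Literal["ping"] = "ping"
    params: Optional[dict[str, Any]] = None

class EchoRequest(BaseModel):
    jsonrpc: Literal["2.0"] = "2.0"
    id: Optional[str] = Field(default_factory=lambda: uuid4().hex)
    method: Literal["echo"] = "echo"
    params: Optional[dict[str, Any]] = None

# The discriminator routes the raw dict to the matching model by "method"
Request = TypeAdapter(
    Annotated[Union[PingRequest, EchoRequest], Field(discriminator="method")]
)

req = Request.validate_python({"method": "echo", "params": {"text": "hi"}})
print(type(req).__name__)  # EchoRequest
```

An unknown `method` raises a `ValidationError` instead of silently matching a model, which is what lets the server map bad requests to JSON-RPC error responses.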
# Agent Card schemas
class AgentProvider(BaseModel):
"""
Represents the provider of an agent.
"""
organization: str
url: str | None = None
class AgentCapabilities(BaseModel):
"""
Represents the capabilities of an agent.
"""
streaming: bool = False
pushNotifications: bool = False
stateTransitionHistory: bool = False
class AgentAuthentication(BaseModel):
"""
Represents the authentication requirements for an agent.
"""
schemes: List[str]
credentials: str | None = None
class AgentSkill(BaseModel):
"""
Represents a skill of an agent.
"""
id: str
name: str
description: str | None = None
tags: List[str] | None = None
examples: List[str] | None = None
inputModes: List[str] | None = None
outputModes: List[str] | None = None
class AgentCard(BaseModel):
"""
Represents an agent card in the A2A protocol.
"""
name: str
description: str | None = None
url: str
provider: AgentProvider | None = None
version: str
documentationUrl: str | None = None
capabilities: AgentCapabilities
authentication: AgentAuthentication | None = None
defaultInputModes: List[str] = ["text"]
defaultOutputModes: List[str] = ["text"]
skills: List[AgentSkill]
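A trimmed-down sketch of constructing such a card (field subset and all values here are illustrative, not a real deployment):

```python
from typing import List
from pydantic import BaseModel

class Capabilities(BaseModel):
    streaming: bool = False
    pushNotifications: bool = False
    stateTransitionHistory: bool = False

class CardSketch(BaseModel):
    """Subset of the AgentCard fields, enough to show the defaults."""
    name: str
    url: str
    version: str
    capabilities: Capabilities
    defaultInputModes: List[str] = ["text"]
    defaultOutputModes: List[str] = ["text"]

card = CardSketch(
    name="demo-agent",                            # hypothetical values
    url="http://localhost:8000/api/v1/a2a/demo",
    version="1.0.0",
    capabilities=Capabilities(streaming=True),
)
print(card.model_dump_json(exclude_none=True))
```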

View File

@ -0,0 +1,124 @@
"""
A2A (Agent-to-Agent) protocol validators.
This module contains validators for the A2A protocol data.
"""
from typing import List
import base64
import re
from pydantic import ValidationError
import logging
from src.schemas.a2a.types import Part, TextPart, FilePart, DataPart, FileContent
logger = logging.getLogger(__name__)
def validate_base64(value: str) -> bool:
"""
Validates if a string is valid base64.
Args:
value: String to validate
Returns:
True if valid base64, False otherwise
"""
try:
if not value:
return False
# Check if the string has base64 characters only
pattern = r"^[A-Za-z0-9+/]+={0,2}$"
if not re.match(pattern, value):
return False
# Try to decode
base64.b64decode(value)
return True
except Exception as e:
logger.warning(f"Invalid base64 string: {e}")
return False
def validate_file_content(file_content: FileContent) -> bool:
"""
Validates file content.
Args:
file_content: FileContent to validate
Returns:
True if valid, False otherwise
"""
try:
if file_content.bytes is not None:
return validate_base64(file_content.bytes)
elif file_content.uri is not None:
# Basic URL validation
pattern = r"^https?://.+"
return bool(re.match(pattern, file_content.uri))
return False
except Exception as e:
logger.warning(f"Invalid file content: {e}")
return False
def validate_message_parts(parts: List[Part]) -> bool:
"""
Validates all parts in a message.
Args:
parts: List of parts to validate
Returns:
True if all parts are valid, False otherwise
"""
try:
for part in parts:
if isinstance(part, TextPart):
if not part.text or not isinstance(part.text, str):
return False
elif isinstance(part, FilePart):
if not validate_file_content(part.file):
return False
elif isinstance(part, DataPart):
if not part.data or not isinstance(part.data, dict):
return False
else:
return False
return True
    except Exception as e:  # ValidationError is already an Exception subclass
logger.warning(f"Invalid message parts: {e}")
return False
def text_to_parts(text: str) -> List[Part]:
"""
Converts a plain text to a list of message parts.
Args:
text: Plain text to convert
Returns:
List containing a single TextPart
"""
return [TextPart(text=text)]
def parts_to_text(parts: List[Part]) -> str:
"""
Extracts text from a list of message parts.
Args:
parts: List of parts to extract text from
Returns:
Concatenated text from all text parts
"""
text = ""
for part in parts:
if isinstance(part, TextPart):
text += part.text
# Could add handling for other part types here
return text
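The base64 rule above can be exercised on its own; this self-contained sketch reimplements it so the snippet runs standalone (it uses `validate=True`, slightly stricter than the original's lenient decode):

```python
import base64
import re

def is_valid_base64(value: str) -> bool:
    """Charset check first, then a strict decode, as in validate_base64."""
    if not value or not re.match(r"^[A-Za-z0-9+/]+={0,2}$", value):
        return False
    try:
        base64.b64decode(value, validate=True)
        return True
    except Exception:
        return False

print(is_valid_base64(base64.b64encode(b"hello").decode()))  # True
print(is_valid_base64("not base64!!"))                       # False
```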

View File

@ -1 +1,9 @@
from .agent_runner import run_agent
from .redis_cache_service import RedisCacheService
from .a2a_task_manager_service import A2ATaskManager
from .a2a_server_service import A2AServer
from .a2a_integration_service import (
AgentRunnerAdapter,
StreamingServiceAdapter,
create_agent_card_from_agent,
)

View File

@ -0,0 +1,520 @@
"""
A2A Integration Service.
This service provides adapters to integrate existing services with the A2A protocol.
"""
import json
import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, AsyncIterable
from src.schemas.a2a import (
AgentCard,
AgentCapabilities,
AgentProvider,
Artifact,
Message,
TaskArtifactUpdateEvent,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
logger = logging.getLogger(__name__)
class AgentRunnerAdapter:
"""
Adapter for integrating the existing agent runner with the A2A protocol.
"""
def __init__(
self,
agent_runner_func,
session_service=None,
artifacts_service=None,
memory_service=None,
db=None,
):
"""
Initialize the adapter.
Args:
agent_runner_func: The agent runner function (e.g., run_agent)
session_service: Session service for message history
artifacts_service: Artifacts service for artifact history
memory_service: Memory service for agent memory
db: Database session
"""
self.agent_runner_func = agent_runner_func
self.session_service = session_service
self.artifacts_service = artifacts_service
self.memory_service = memory_service
self.db = db
async def get_supported_modes(self) -> List[str]:
"""
Get the supported output modes for the agent.
Returns:
List of supported output modes
"""
# Default modes, can be extended based on agent configuration
return ["text", "application/json"]
async def run_agent(
self,
agent_id: str,
message: str,
session_id: Optional[str] = None,
task_id: Optional[str] = None,
db=None,
) -> Dict[str, Any]:
"""
Run the agent with the given message.
Args:
agent_id: ID of the agent to run
message: User message to process
session_id: Optional session ID for conversation context
task_id: Optional task ID for tracking
db: Database session
Returns:
Dictionary with the agent's response
"""
logger.info(
f"[AGENT-RUNNER] run_agent started - agent_id={agent_id}, task_id={task_id}, session_id={session_id}"
)
logger.info(
f"[AGENT-RUNNER] run_agent - message: '{message[:50]}...' (truncated)"
)
try:
# Use the existing agent runner function
session_id = session_id or str(uuid.uuid4())
task_id = task_id or str(uuid.uuid4())
# Use the provided db or fallback to self.db
db_session = db if db is not None else self.db
if db_session is None:
logger.error(
f"[AGENT-RUNNER] No database session available. db={db}, self.db={self.db}"
)
else:
logger.info(
f"[AGENT-RUNNER] Using database session: {type(db_session).__name__}"
)
logger.info(
f"[AGENT-RUNNER] Calling agent_runner_func with agent_id={agent_id}, contact_id={task_id}"
)
response_text = await self.agent_runner_func(
agent_id=agent_id,
contact_id=task_id,
message=message,
session_service=self.session_service,
artifacts_service=self.artifacts_service,
memory_service=self.memory_service,
db=db_session,
session_id=session_id,
)
logger.info(
f"[AGENT-RUNNER] run_agent completed successfully for agent_id={agent_id}, task_id={task_id}"
)
logger.info(
f"[AGENT-RUNNER] response: '{str(response_text)[:50]}...' (truncated)"
)
return {
"status": "success",
"content": response_text,
"session_id": session_id,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"[AGENT-RUNNER] Error running agent: {e}", exc_info=True)
return {
"status": "error",
"error": str(e),
"session_id": session_id,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
}
async def cancel_task(self, task_id: str) -> bool:
"""
Cancel a running task.
Args:
task_id: ID of the task to cancel
Returns:
True if successfully canceled, False otherwise
"""
# Currently, the agent runner doesn't support cancellation
# This is a placeholder for future implementation
logger.warning(f"Task cancellation not implemented for task {task_id}")
return False
class StreamingServiceAdapter:
"""
Adapter for integrating the existing streaming service with the A2A protocol.
"""
def __init__(self, streaming_service):
"""
Initialize the adapter.
Args:
streaming_service: The streaming service instance
"""
self.streaming_service = streaming_service
async def stream_agent_response(
self,
agent_id: str,
message: str,
api_key: str,
session_id: Optional[str] = None,
task_id: Optional[str] = None,
db=None,
) -> AsyncIterable[str]:
"""
Stream the agent's response as A2A events.
Args:
agent_id: ID of the agent
message: User message to process
api_key: API key for authentication
session_id: Optional session ID for conversation context
task_id: Optional task ID for tracking
db: Database session
Yields:
A2A event objects as JSON strings for SSE (Server-Sent Events)
"""
task_id = task_id or str(uuid.uuid4())
logger.info(f"Starting streaming response for task {task_id}")
# Set working status event
working_status = TaskStatus(
state="working",
timestamp=datetime.now(),
message=Message(
role="agent", parts=[TextPart(text="Processing your request...")]
),
)
status_event = TaskStatusUpdateEvent(
id=task_id, status=working_status, final=False
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(status_event.model_dump())
content_buffer = ""
final_sent = False
has_error = False
# Stream from the existing streaming service
try:
logger.info(f"Setting up streaming for agent {agent_id}, task {task_id}")
# For streaming, we use task_id as the contact_id
contact_id = task_id
# Heartbeat handling keeps the SSE connection alive
last_event_time = datetime.now()
heartbeat_interval = 20  # seconds
async for event in self.streaming_service.send_task_streaming(
agent_id=agent_id,
api_key=api_key,
message=message,
contact_id=contact_id,
session_id=session_id,
db=db,
):
# Update the timestamp of the last event
last_event_time = datetime.now()
# Process the streaming event format
event_data = event.get("data", "{}")
try:
logger.info(f"Processing event data: {event_data[:100]}...")
data = json.loads(event_data)
# Extract content
if "delta" in data and data["delta"].get("content"):
content = data["delta"]["content"]
content_buffer += content
logger.info(f"Received content chunk: {content[:50]}...")
# Create artifact update event
artifact = Artifact(
name="response",
parts=[TextPart(text=content)],
index=0,
append=True,
lastChunk=False,
)
artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=artifact
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(artifact_event.model_dump())
# Check if final event
if data.get("done", False) and not final_sent:
logger.info(f"Received final event for task {task_id}")
# Create completed status event
completed_status = TaskStatus(
state="completed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[
TextPart(text=content_buffer or "Task completed")
],
),
)
# Final artifact with full content
final_artifact = Artifact(
name="response",
parts=[TextPart(text=content_buffer)],
index=0,
append=False,
lastChunk=True,
)
# Send the final artifact
final_artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=final_artifact
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(final_artifact_event.model_dump())
# Send the completed status
final_status_event = TaskStatusUpdateEvent(
id=task_id,
status=completed_status,
final=True,
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(final_status_event.model_dump())
final_sent = True
except json.JSONDecodeError as e:
logger.warning(
f"Received non-JSON event data: {e}. Data: {event_data[:100]}..."
)
# Handle non-JSON events - simply add to buffer as text
if isinstance(event_data, str):
content_buffer += event_data
# Create artifact update event
artifact = Artifact(
name="response",
parts=[TextPart(text=event_data)],
index=0,
append=True,
lastChunk=False,
)
artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=artifact
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(artifact_event.model_dump())
elif isinstance(event_data, dict):
# Try to extract text from the dictionary
text_value = str(event_data)
content_buffer += text_value
artifact = Artifact(
name="response",
parts=[TextPart(text=text_value)],
index=0,
append=True,
lastChunk=False,
)
artifact_event = TaskArtifactUpdateEvent(
id=task_id, artifact=artifact
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(artifact_event.model_dump())
# Send a heartbeat/keep-alive to hold the SSE connection open
now = datetime.now()
if (now - last_event_time).total_seconds() > heartbeat_interval:
logger.info(f"Sending heartbeat for task {task_id}")
# Send the keep-alive as a "working" status event
working_heartbeat = TaskStatus(
state="working",
timestamp=now,
message=Message(
role="agent", parts=[TextPart(text="Still processing...")]
),
)
heartbeat_event = TaskStatusUpdateEvent(
id=task_id, status=working_heartbeat, final=False
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(heartbeat_event.model_dump())
last_event_time = now
# Ensure we send a final event if not already sent
if not final_sent:
logger.info(
f"Stream completed for task {task_id}, sending final status"
)
# Create completed status event
completed_status = TaskStatus(
state="completed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[TextPart(text=content_buffer or "Task completed")],
),
)
# Send the completed status
final_event = TaskStatusUpdateEvent(
id=task_id, status=completed_status, final=True
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(final_event.model_dump())
except Exception as e:
has_error = True
logger.error(f"Error in streaming for task {task_id}: {e}", exc_info=True)
# Create failed status event
failed_status = TaskStatus(
state="failed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[
TextPart(
text=f"Error during streaming: {str(e)}. Partial response: {content_buffer[:200] if content_buffer else 'No content received'}"
)
],
),
)
error_event = TaskStatusUpdateEvent(
id=task_id, status=failed_status, final=True
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(error_event.model_dump())
finally:
# Ensure a final event is sent so the connection closes cleanly
if not final_sent and not has_error:
logger.info(f"Stream finalizing for task {task_id} via finally block")
try:
# Create completed status event
completed_status = TaskStatus(
state="completed",
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[
TextPart(
text=content_buffer or "Task completed (forced end)"
)
],
),
)
# Send the completed status
final_event = TaskStatusUpdateEvent(
id=task_id, status=completed_status, final=True
)
# IMPORTANT: serialize to a JSON string for SSE
yield json.dumps(final_event.model_dump())
except Exception as final_error:
logger.error(
f"Error sending final event in finally block: {final_error}"
)
logger.info(f"Streaming completed for task {task_id}")
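The adapter yields bare JSON strings; on the wire each one is framed as an SSE `data:` message (sse-starlette does this framing when given `{"data": ...}` dicts). A minimal sketch of that framing, with an illustrative event shape:

```python
import json

def to_sse_frame(event: dict) -> str:
    """Frame one JSON event as a Server-Sent Events message."""
    return f"data: {json.dumps(event)}\n\n"

frame = to_sse_frame({"id": "task-1", "final": False})
print(repr(frame))  # 'data: {"id": "task-1", "final": false}\n\n'
```

The blank line (the double `\n`) is what terminates an SSE event, which is why every yielded payload must be a single serialized string rather than a raw dict.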
def create_agent_card_from_agent(agent, db) -> AgentCard:
"""
Create an A2A agent card from an agent model.
Args:
agent: The agent model from the database
db: Database session
Returns:
A2A AgentCard object
"""
import os
from src.api.agent_routes import format_agent_tools
import asyncio
# Extract agent configuration
agent_config = agent.config
has_streaming = True # Assuming streaming is always supported
has_push = True # Assuming push notifications are supported
# Format tools as skills
try:
# We use a different approach to handle the asynchronous function
mcp_servers = agent_config.get("mcp_servers", [])
# We create a new thread to execute the asynchronous function
import concurrent.futures
import functools
def run_async(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coro)
loop.close()
return result
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_async, format_agent_tools(mcp_servers, db))
skills = future.result()
except Exception as e:
logger.error(f"Error formatting agent tools: {e}")
skills = []
# Create agent card
return AgentCard(
name=agent.name,
description=agent.description,
url=f"{os.getenv('API_URL', '')}/api/v1/a2a/{agent.id}",
provider=AgentProvider(
organization=os.getenv("ORGANIZATION_NAME", ""),
url=os.getenv("ORGANIZATION_URL", ""),
),
version=os.getenv("API_VERSION", "1.0.0"),
capabilities=AgentCapabilities(
streaming=has_streaming,
pushNotifications=has_push,
stateTransitionHistory=True,
),
authentication={
"schemes": ["apiKey"],
"credentials": "x-api-key",
},
defaultInputModes=["text", "application/json"],
defaultOutputModes=["text", "application/json"],
skills=skills,
)

View File

@ -0,0 +1,739 @@
"""
Server A2A and task manager for the A2A protocol.
This module implements a JSON-RPC compatible server for the A2A protocol,
that manages agent tasks, streaming events and push notifications.
"""
import asyncio
import json
import logging
import uuid
from datetime import datetime
from typing import (
Any,
Dict,
List,
Optional,
AsyncGenerator,
Callable,
Union,
AsyncIterable,
)
import httpx
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse, Response
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from src.schemas.a2a.types import A2ARequest
from src.services.agent_runner import run_agent
from src.services.a2a_integration_service import (
AgentRunnerAdapter,
StreamingServiceAdapter,
)
from src.services.session_service import get_session_events
from src.services.redis_cache_service import RedisCacheService
from src.schemas.a2a.types import (
SendTaskRequest,
SendTaskStreamingRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
TaskResubscriptionRequest,
TaskSendParams,
)
from src.utils.a2a_utils import are_modalities_compatible
logger = logging.getLogger(__name__)
class A2ATaskManager:
"""
Task manager for the A2A protocol.
This class manages the lifecycle of tasks, including:
- Task execution
- Streaming of events
- Push notifications
- Status querying
- Cancellation
"""
def __init__(
self,
redis_cache: RedisCacheService,
agent_runner: AgentRunnerAdapter,
streaming_service: StreamingServiceAdapter,
push_notification_service: Any = None,
):
"""
Initialize the task manager.
Args:
redis_cache: Cache service for storing task data
agent_runner: Adapter for agent execution
streaming_service: Adapter for event streaming
push_notification_service: Service for sending push notifications
"""
self.cache = redis_cache
self.agent_runner = agent_runner
self.streaming_service = streaming_service
self.push_notification_service = push_notification_service
self._running_tasks = {}
self.db = None  # database session; A2AServer injects one per request
async def on_send_task(
self,
task_id: str,
agent_id: str,
message: Dict[str, Any],
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
input_mode: str = "text",
output_modes: List[str] = ["text"],
db: Optional[Session] = None,
) -> Dict[str, Any]:
"""
Process a request to send a task.
Args:
task_id: Task ID
agent_id: Agent ID
message: User message
session_id: Session ID (optional)
metadata: Additional metadata (optional)
input_mode: Input mode (text, JSON, etc.)
output_modes: Supported output modes
db: Database session
Returns:
Response with task result
"""
if not session_id:
session_id = f"{task_id}_{agent_id}"
if not metadata:
metadata = {}
# Update status to "submitted"
task_data = {
"id": task_id,
"sessionId": session_id,
"status": {
"state": "submitted",
"timestamp": datetime.now().isoformat(),
"message": None,
"error": None,
},
"artifacts": [],
"history": [],
"metadata": metadata,
}
# Store initial task data
await self.cache.set(f"task:{task_id}", task_data)
# Check for push notification configurations
push_config = await self.cache.get(f"task:{task_id}:push")
if push_config and self.push_notification_service:
# Send initial notification
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="submitted",
headers=push_config.get("headers", {}),
)
try:
# Update status to "running"
task_data["status"].update(
{"state": "running", "timestamp": datetime.now().isoformat()}
)
await self.cache.set(f"task:{task_id}", task_data)
# Notify "running" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="running",
headers=push_config.get("headers", {}),
)
# Extract user message
user_message = None
try:
user_message = message["parts"][0]["text"]
except (KeyError, IndexError):
user_message = ""
# Execute the agent
response = await self.agent_runner.run_agent(
agent_id=agent_id,
task_id=task_id,
message=user_message,
session_id=session_id,
db=db,
)
# Check if the response is a dictionary (error) or a string (success)
if isinstance(response, dict) and response.get("status") == "error":
# Error response
final_response = f"Error: {response.get('error', 'Unknown error')}"
# Update status to "failed"
task_data["status"].update(
{
"state": "failed",
"timestamp": datetime.now().isoformat(),
"error": {
"code": "AGENT_EXECUTION_ERROR",
"message": response.get("error", "Unknown error"),
},
"message": {
"role": "system",
"parts": [{"type": "text", "text": final_response}],
},
}
)
# Notify "failed" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="failed",
message={
"role": "system",
"parts": [{"type": "text", "text": final_response}],
},
headers=push_config.get("headers", {}),
)
else:
# Success response
final_response = (
response.get("content") if isinstance(response, dict) else response
)
# Update status to "completed"
task_data["status"].update(
{
"state": "completed",
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": final_response}],
},
}
)
# Add artifacts
if final_response:
task_data["artifacts"].append(
{
"type": "text",
"content": final_response,
"metadata": {
"generated_at": datetime.now().isoformat(),
"content_type": "text/plain",
},
}
)
# Add history of messages
history_length = metadata.get("historyLength", 50)
try:
history_messages = get_session_events(
self.agent_runner.session_service, session_id
)
history_messages = history_messages[-history_length:]
formatted_history = []
for event in history_messages:
if event.content and event.content.parts:
role = (
"agent"
if event.content.role == "model"
else event.content.role
)
formatted_history.append(
{
"role": role,
"parts": [
{"type": "text", "text": part.text}
for part in event.content.parts
if part.text
],
}
)
task_data["history"] = formatted_history
except Exception as e:
logger.error(f"Error processing history: {str(e)}")
# Notify "completed" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="completed",
message={
"role": "agent",
"parts": [{"type": "text", "text": final_response}],
},
headers=push_config.get("headers", {}),
)
except Exception as e:
logger.error(f"Error executing task {task_id}: {str(e)}")
# Update status to "failed"
task_data["status"].update(
{
"state": "failed",
"timestamp": datetime.now().isoformat(),
"error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)},
}
)
# Notify "failed" state
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="failed",
message={
"role": "system",
"parts": [{"type": "text", "text": str(e)}],
},
headers=push_config.get("headers", {}),
)
# Store final result
await self.cache.set(f"task:{task_id}", task_data)
return task_data
async def on_send_task_subscribe(
self,
task_id: str,
agent_id: str,
message: Dict[str, Any],
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
input_mode: str = "text",
output_modes: List[str] = ["text"],
db: Optional[Session] = None,
) -> AsyncGenerator[str, None]:
"""
Process a request to send a task with streaming.
Args:
task_id: Task ID
agent_id: Agent ID
message: User message
session_id: Session ID (optional)
metadata: Additional metadata (optional)
input_mode: Input mode (text, JSON, etc.)
output_modes: Supported output modes
db: Database session
Yields:
Streaming events in SSE (Server-Sent Events) format
"""
if not session_id:
session_id = f"{task_id}_{agent_id}"
if not metadata:
metadata = {}
# Extract user message
user_message = ""
try:
user_message = message["parts"][0]["text"]
except (KeyError, IndexError):
pass
# Generate streaming events
async for event in self.streaming_service.stream_response(
agent_id=agent_id,
task_id=task_id,
message=user_message,
session_id=session_id,
db=db,
):
yield event
async def on_get_task(self, task_id: str) -> Dict[str, Any]:
"""
Query the status of a task by ID.
Args:
task_id: Task ID
Returns:
Current task status
Raises:
Exception: If the task is not found
"""
task_data = await self.cache.get(f"task:{task_id}")
if not task_data:
raise Exception(f"Task {task_id} not found")
return task_data
async def on_cancel_task(self, task_id: str) -> Dict[str, Any]:
"""
Cancel a running task.
Args:
task_id: Task ID to be cancelled
Returns:
Task status after cancellation
Raises:
Exception: If the task is not found or cannot be cancelled
"""
task_data = await self.cache.get(f"task:{task_id}")
if not task_data:
raise Exception(f"Task {task_id} not found")
# Check if the task is in a state that can be cancelled
current_state = task_data["status"]["state"]
if current_state not in ["submitted", "running"]:
raise Exception(f"Cannot cancel task in {current_state} state")
# Cancel the task in the runner if it is running
running_task = self._running_tasks.get(task_id)
if running_task:
# Try to cancel the running task
if hasattr(running_task, "cancel"):
running_task.cancel()
# Update status to "cancelled"
task_data["status"].update(
{
"state": "cancelled",
"timestamp": datetime.now().isoformat(),
"message": {
"role": "system",
"parts": [{"type": "text", "text": "Task cancelled by user"}],
},
}
)
# Update cache
await self.cache.set(f"task:{task_id}", task_data)
# Send push notification if configured
push_config = await self.cache.get(f"task:{task_id}:push")
if push_config and self.push_notification_service:
await self.push_notification_service.send_notification(
url=push_config["url"],
task_id=task_id,
state="cancelled",
message={
"role": "system",
"parts": [{"type": "text", "text": "Task cancelled by user"}],
},
headers=push_config.get("headers", {}),
)
return task_data
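The `_running_tasks` lookup above only assumes the stored object exposes `cancel()`; with `asyncio.Task` that pattern looks like the following sketch (the job body is hypothetical):

```python
import asyncio

async def long_job() -> str:
    try:
        await asyncio.sleep(60)              # stands in for agent execution
        return "completed"
    except asyncio.CancelledError:
        return "cancelled"                   # swallow cancellation to report state

async def main() -> str:
    running_tasks = {}
    running_tasks["task-1"] = asyncio.create_task(long_job())
    await asyncio.sleep(0)                   # give the task a chance to start
    task = running_tasks.get("task-1")
    if task and hasattr(task, "cancel"):     # mirrors the check in on_cancel_task
        task.cancel()
    return await task

print(asyncio.run(main()))  # cancelled
```

Catching `CancelledError` inside the coroutine lets it report a terminal state instead of propagating the exception to the awaiter, which matches how the task manager records `"cancelled"` in the cache.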
async def on_set_task_push_notification(
self, task_id: str, notification_config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Configure push notifications for a task.
Args:
task_id: Task ID
notification_config: Notification configuration (URL and headers)
Returns:
Updated configuration
"""
# Validate configuration
url = notification_config.get("url")
if not url:
raise ValueError("Push notification URL is required")
headers = notification_config.get("headers", {})
# Store configuration
config = {"url": url, "headers": headers}
await self.cache.set(f"task:{task_id}:push", config)
return config
async def on_get_task_push_notification(self, task_id: str) -> Dict[str, Any]:
"""
Get the push notification configuration for a task.
Args:
task_id: Task ID
Returns:
Push notification configuration
Raises:
Exception: If there is no configuration for the task
"""
config = await self.cache.get(f"task:{task_id}:push")
if not config:
raise Exception(f"No push notification configuration for task {task_id}")
return config
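The push notification service itself is injected, so its wire format is not shown here; a sketch of what a notification body might look like (the field names are assumptions for illustration, not a confirmed protocol shape):

```python
from datetime import datetime, timezone
from typing import Optional

def build_push_payload(task_id: str, state: str, message: Optional[dict] = None) -> dict:
    """Assemble a push-notification body for a task state change (hypothetical shape)."""
    payload = {
        "taskId": task_id,
        "state": state,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if message is not None:
        payload["message"] = message
    return payload

body = build_push_payload("task-1", "completed")
print(sorted(body))  # ['state', 'taskId', 'timestamp']
```

A sender would POST this body (e.g. with `httpx.AsyncClient`) to the stored `config["url"]`, passing `config["headers"]` along for authentication.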
class A2AServer:
"""
A2A server compatible with JSON-RPC 2.0.
This class processes JSON-RPC requests and forwards them to
the appropriate handlers in the A2ATaskManager.
"""
def __init__(self, task_manager: A2ATaskManager, agent_card=None):
"""
Initialize the A2A server.
Args:
task_manager: Task manager
agent_card: Agent card information
"""
self.task_manager = task_manager
self.agent_card = agent_card
async def process_request(
self,
request: Request,
agent_id: Optional[str] = None,
db: Optional[Session] = None,
) -> Union[Response, JSONResponse, StreamingResponse]:
"""
Process a JSON-RPC request.
Args:
request: HTTP request
agent_id: Optional agent ID to inject into the request
db: Database session
Returns:
Appropriate response (JSON or Streaming)
"""
try:
# Try to parse the JSON payload
try:
logger.info("Starting JSON-RPC request processing")
body = await request.json()
logger.info(f"Received JSON data: {json.dumps(body)}")
method = body.get("method", "unknown")
logger.info(f"[SERVER] Processing method: {method}")
# Validate the request using the A2A validator
json_rpc_request = A2ARequest.validate_python(body)
logger.info(
f"[SERVER] Request validated as: {type(json_rpc_request).__name__}"
)
original_db = self.task_manager.db
try:
# Set the db temporarily
if db is not None:
self.task_manager.db = db
# Process the request
if isinstance(json_rpc_request, SendTaskRequest):
logger.info(
f"[SERVER] Processing SendTaskRequest for task_id={json_rpc_request.params.id}"
)
json_rpc_request.params.agentId = agent_id
result = await self.task_manager.on_send_task(json_rpc_request)
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
logger.info(
f"[SERVER] Processing SendTaskStreamingRequest for task_id={json_rpc_request.params.id}"
)
json_rpc_request.params.agentId = agent_id
result = await self.task_manager.on_send_task_subscribe(
json_rpc_request
)
elif isinstance(json_rpc_request, GetTaskRequest):
logger.info(
f"[SERVER] Processing GetTaskRequest for task_id={json_rpc_request.params.id}"
)
result = await self.task_manager.on_get_task(json_rpc_request)
elif isinstance(json_rpc_request, CancelTaskRequest):
logger.info(
f"[SERVER] Processing CancelTaskRequest for task_id={json_rpc_request.params.id}"
)
result = await self.task_manager.on_cancel_task(
json_rpc_request
)
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
logger.info(
f"[SERVER] Processing SetTaskPushNotificationRequest for task_id={json_rpc_request.params.id}"
)
result = await self.task_manager.on_set_task_push_notification(
json_rpc_request
)
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
logger.info(
f"[SERVER] Processing GetTaskPushNotificationRequest for task_id={json_rpc_request.params.id}"
)
result = await self.task_manager.on_get_task_push_notification(
json_rpc_request
)
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
logger.info(
f"[SERVER] Processing TaskResubscriptionRequest for task_id={json_rpc_request.params.id}"
)
result = await self.task_manager.on_resubscribe_to_task(
json_rpc_request
)
else:
logger.warning(
f"[SERVER] Unsupported request type: {type(json_rpc_request)}"
)
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id"),
"error": {
"code": -32601,
"message": "Method not found",
"data": {"detail": "Method not supported"},
},
},
)
finally:
# Restore the original db
self.task_manager.db = original_db
# Create appropriate response
return self._create_response(result)
except json.JSONDecodeError as e:
# Error parsing JSON
logger.error(f"Error parsing JSON request: {str(e)}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32700,
"message": "Parse error",
"data": {"detail": str(e)},
},
},
)
except Exception as e:
# Other validation errors
logger.error(f"Error validating request: {str(e)}")
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id") if "body" in locals() else None,
"error": {
"code": -32600,
"message": "Invalid Request",
"data": {"detail": str(e)},
},
},
)
except Exception as e:
logger.error(f"Error processing JSON-RPC request: {str(e)}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32603,
"message": "Internal error",
"data": {"detail": str(e)},
},
},
)
def _create_response(self, result: Any) -> Union[JSONResponse, StreamingResponse]:
"""
Create appropriate response based on result type.
Args:
result: Result from task manager
Returns:
JSON or streaming response
"""
if isinstance(result, AsyncIterable):
# Result in streaming (SSE)
async def event_generator():
async for item in result:
if hasattr(item, "model_dump_json"):
yield {"data": item.model_dump_json(exclude_none=True)}
else:
yield {"data": json.dumps(item)}
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
elif hasattr(result, "model_dump"):
# Result is a Pydantic object
return JSONResponse(result.model_dump(exclude_none=True))
else:
# Result is a dictionary or other simple type
return JSONResponse(result)
async def get_agent_card(
self, request: Request, db: Optional[Session] = None
) -> JSONResponse:
"""
Get the agent card.
Args:
request: HTTP request
db: Database session
Returns:
Agent card as JSON
"""
if not self.agent_card:
logger.error("Agent card not configured")
return JSONResponse(
status_code=404, content={"error": "Agent card not configured"}
)
# If there is db, set it temporarily in the task_manager
if db and hasattr(self.task_manager, "db"):
original_db = self.task_manager.db
try:
self.task_manager.db = db
# If it's a Pydantic object, convert to dictionary
if hasattr(self.agent_card, "model_dump"):
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
else:
return JSONResponse(self.agent_card)
finally:
# Restore the original db
self.task_manager.db = original_db
else:
# If it's a Pydantic object, convert to dictionary
if hasattr(self.agent_card, "model_dump"):
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
else:
return JSONResponse(self.agent_card)
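The `_create_response` dispatch above can be sketched in isolation. `pick_response_kind` is a hypothetical helper used only to illustrate the three branches (SSE streaming for async iterables, `model_dump()` for Pydantic objects, plain JSON otherwise):

```python
from collections.abc import AsyncIterable

def pick_response_kind(result):
    # Mirrors the branch order in _create_response:
    # async iterables stream as SSE, Pydantic objects are dumped,
    # anything else is returned as plain JSON.
    if isinstance(result, AsyncIterable):
        return "sse"
    if hasattr(result, "model_dump"):
        return "pydantic-json"
    return "plain-json"

async def fake_stream():
    # An async generator stands in for the task manager's event stream.
    yield {"status": "working"}

print(pick_response_kind(fake_stream()))  # sse
print(pick_response_kind({"jsonrpc": "2.0"}))  # plain-json
```

Note that the `AsyncIterable` check must come first: a Pydantic model never streams, but an async generator would also fail the `model_dump` check, so the order only matters for objects that implement both.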


@ -0,0 +1,888 @@
"""
A2A Task Manager Service.
This service implements task management for the A2A protocol, handling task lifecycle
including execution, streaming, push notifications, status queries, and cancellation.
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, AsyncIterable
from sqlalchemy.orm import Session
from src.schemas.a2a.exceptions import (
TaskNotFoundError,
TaskNotCancelableError,
PushNotificationNotSupportedError,
InternalError,
ContentTypeNotSupportedError,
)
from src.schemas.a2a.types import (
JSONRPCResponse,
TaskIdParams,
TaskQueryParams,
GetTaskRequest,
SendTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
GetTaskResponse,
CancelTaskResponse,
SendTaskResponse,
SetTaskPushNotificationResponse,
GetTaskPushNotificationResponse,
TaskSendParams,
TaskStatus,
TaskState,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
Artifact,
PushNotificationConfig,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
JSONRPCError,
TaskPushNotificationConfig,
Message,
TextPart,
Task,
)
from src.services.redis_cache_service import RedisCacheService
from src.utils.a2a_utils import (
are_modalities_compatible,
new_incompatible_types_error,
new_not_implemented_error,
create_task_id,
format_error_response,
)
logger = logging.getLogger(__name__)
class A2ATaskManager:
"""
A2A Task Manager implementation.
This class manages the lifecycle of A2A tasks, including:
- Task submission and execution
- Task status queries
- Task cancellation
- Push notification configuration
- SSE streaming for real-time updates
"""
def __init__(
self,
redis_cache: RedisCacheService,
agent_runner=None,
streaming_service=None,
push_notification_service=None,
db=None,
):
"""
Initialize the A2A Task Manager.
Args:
redis_cache: Redis cache service for task storage
agent_runner: Agent runner service for task execution
streaming_service: Streaming service for SSE
push_notification_service: Service for push notifications
db: Database session
"""
self.redis_cache = redis_cache
self.agent_runner = agent_runner
self.streaming_service = streaming_service
self.push_notification_service = push_notification_service
self.db = db
self.lock = asyncio.Lock()
self.subscriber_lock = asyncio.Lock()
self.task_sse_subscribers = {}
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
"""
Handle a request to get information about a task.
Args:
request: A2A Get Task request
Returns:
Response with the task details
"""
try:
task_id = request.params.id
history_length = request.params.historyLength
# Fetch task data from the cache
task_data = await self.redis_cache.get(f"task:{task_id}")
if not task_data:
logger.warning(f"Task not found: {task_id}")
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
# Build a Task instance from the cached data
task = Task.model_validate(task_data)
# If the historyLength parameter is present, trim the history
if history_length is not None and task.history:
if history_length == 0:
task.history = []
elif len(task.history) > history_length:
task.history = task.history[-history_length:]
return GetTaskResponse(id=request.id, result=task)
except Exception as e:
logger.error(f"Error processing on_get_task: {str(e)}")
return GetTaskResponse(id=request.id, error=InternalError(message=str(e)))
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
"""
Handle request to cancel a running task.
Args:
request: The JSON-RPC request to cancel a task
Returns:
Response with updated task data or error
"""
logger.info(f"Cancelling task {request.params.id}")
task_id_params = request.params
try:
task_data = await self.redis_cache.get(f"task:{task_id_params.id}")
if not task_data:
logger.warning(f"Task {task_id_params.id} not found for cancellation")
return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
# Check if task can be cancelled
current_state = task_data.get("status", {}).get("state")
if current_state not in [TaskState.SUBMITTED, TaskState.WORKING]:
logger.warning(
f"Task {task_id_params.id} in state {current_state} cannot be cancelled"
)
return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
# Update task status to cancelled
task_data["status"] = {
"state": TaskState.CANCELED,
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": "Task cancelled by user"}],
},
}
# Save updated task data
await self.redis_cache.set(f"task:{task_id_params.id}", task_data)
# Send push notification if configured
await self._send_push_notification_for_task(
task_id_params.id, "canceled", system_message="Task cancelled by user"
)
# Publish event to SSE subscribers
await self._publish_task_update(
task_id_params.id,
TaskStatusUpdateEvent(
id=task_id_params.id,
status=TaskStatus(
state=TaskState.CANCELED,
timestamp=datetime.now(),
message=Message(
role="agent",
parts=[TextPart(text="Task cancelled by user")],
),
),
final=True,
),
)
return CancelTaskResponse(id=request.id, result=task_data)
except Exception as e:
logger.error(f"Error cancelling task: {str(e)}", exc_info=True)
return CancelTaskResponse(
id=request.id,
error=InternalError(message=f"Error cancelling task: {str(e)}"),
)
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
"""
Handle a request to send a new task.
Args:
request: Task send request
Returns:
Response with details of the created task
"""
try:
params = request.params
task_id = params.id
logger.info(f"Receiving task {task_id}")
# Check whether a task with this ID already exists
existing_task = await self.redis_cache.get(f"task:{task_id}")
if existing_task:
# If the task already exists and is in progress, return the current task
if existing_task.get("status", {}).get("state") in [
TaskState.WORKING,
TaskState.COMPLETED,
]:
logger.info(
f"Task {task_id} already exists and is in progress/completed"
)
return SendTaskResponse(
id=request.id, result=Task.model_validate(existing_task)
)
# If the task exists but failed or was cancelled, it can be reprocessed
logger.info(f"Reprocessing existing task {task_id}")
# Check modality compatibility
server_output_modes = []
if self.agent_runner:
# Try to get the agent's supported modes
try:
server_output_modes = await self.agent_runner.get_supported_modes()
except Exception as e:
logger.warning(f"Error getting supported modes: {str(e)}")
server_output_modes = ["text"]  # Fall back to text
if not are_modalities_compatible(
server_output_modes, params.acceptedOutputModes
):
logger.warning(
f"Incompatible modes: server={server_output_modes}, client={params.acceptedOutputModes}"
)
return SendTaskResponse(
id=request.id, error=ContentTypeNotSupportedError()
)
# Create the task data
task_data = await self._create_task_data(params)
# Store the task in the cache
await self.redis_cache.set(f"task:{task_id}", task_data)
# Configure push notifications, if provided
if params.pushNotification:
await self.redis_cache.set(
f"task_notification:{task_id}", params.pushNotification.model_dump()
)
# Start task execution in the background
asyncio.create_task(self._execute_task(task_data, params))
# Convert to a Task object and return
task = Task.model_validate(task_data)
return SendTaskResponse(id=request.id, result=task)
except Exception as e:
logger.error(f"Error processing on_send_task: {str(e)}")
return SendTaskResponse(id=request.id, error=InternalError(message=str(e)))
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
"""
Handle request to send a task and subscribe to streaming updates.
Args:
request: The JSON-RPC request to send a task with streaming
Returns:
Stream of events or error response
"""
logger.info(f"Sending task with streaming {request.params.id}")
task_send_params = request.params
try:
# Check output mode compatibility
if not are_modalities_compatible(
["text", "application/json"], # Default supported modes
task_send_params.acceptedOutputModes,
):
return new_incompatible_types_error(request.id)
# Create initial task data
task_data = await self._create_task_data(task_send_params)
# Setup SSE consumer
sse_queue = await self._setup_sse_consumer(task_send_params.id)
# Execute task asynchronously (fire and forget)
asyncio.create_task(self._execute_task(task_data, task_send_params))
# Return generator to dequeue events for SSE
return self._dequeue_events_for_sse(
request.id, task_send_params.id, sse_queue
)
except Exception as e:
logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True)
return SendTaskStreamingResponse(
id=request.id,
error=InternalError(
message=f"Error setting up streaming task: {str(e)}"
),
)
async def on_set_task_push_notification(
self, request: SetTaskPushNotificationRequest
) -> SetTaskPushNotificationResponse:
"""
Configure push notifications for a task.
Args:
request: The JSON-RPC request to set push notification
Returns:
Response with configuration or error
"""
logger.info(f"Setting push notification for task {request.params.id}")
task_notification_params = request.params
try:
if not self.push_notification_service:
logger.warning("Push notifications not supported")
return SetTaskPushNotificationResponse(
id=request.id, error=PushNotificationNotSupportedError()
)
# Check if task exists
task_data = await self.redis_cache.get(
f"task:{task_notification_params.id}"
)
if not task_data:
logger.warning(
f"Task {task_notification_params.id} not found for setting push notification"
)
return SetTaskPushNotificationResponse(
id=request.id, error=TaskNotFoundError()
)
# Save push notification config
config = {
"url": task_notification_params.pushNotificationConfig.url,
"headers": {}, # Add auth headers if needed
}
await self.redis_cache.set(
f"task:{task_notification_params.id}:push", config
)
return SetTaskPushNotificationResponse(
id=request.id, result=task_notification_params
)
except Exception as e:
logger.error(f"Error setting push notification: {str(e)}", exc_info=True)
return SetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message=f"Error setting push notification: {str(e)}"
),
)
async def on_get_task_push_notification(
self, request: GetTaskPushNotificationRequest
) -> GetTaskPushNotificationResponse:
"""
Get push notification configuration for a task.
Args:
request: The JSON-RPC request to get push notification config
Returns:
Response with configuration or error
"""
logger.info(f"Getting push notification for task {request.params.id}")
task_params = request.params
try:
if not self.push_notification_service:
logger.warning("Push notifications not supported")
return GetTaskPushNotificationResponse(
id=request.id, error=PushNotificationNotSupportedError()
)
# Check if task exists
task_data = await self.redis_cache.get(f"task:{task_params.id}")
if not task_data:
logger.warning(
f"Task {task_params.id} not found for getting push notification"
)
return GetTaskPushNotificationResponse(
id=request.id, error=TaskNotFoundError()
)
# Get push notification config
config = await self.redis_cache.get(f"task:{task_params.id}:push")
if not config:
logger.warning(f"No push notification config for task {task_params.id}")
return GetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message="No push notification configuration found"
),
)
result = TaskPushNotificationConfig(
id=task_params.id,
pushNotificationConfig=PushNotificationConfig(
url=config.get("url"), token=None, authentication=None
),
)
return GetTaskPushNotificationResponse(id=request.id, result=result)
except Exception as e:
logger.error(f"Error getting push notification: {str(e)}", exc_info=True)
return GetTaskPushNotificationResponse(
id=request.id,
error=InternalError(
message=f"Error getting push notification: {str(e)}"
),
)
async def on_resubscribe_to_task(
self, request: TaskResubscriptionRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
"""
Resubscribe to a task's streaming events.
Args:
request: The JSON-RPC request to resubscribe
Returns:
Stream of events or error response
"""
logger.info(f"Resubscribing to task {request.params.id}")
task_params = request.params
try:
# Check if task exists
task_data = await self.redis_cache.get(f"task:{task_params.id}")
if not task_data:
logger.warning(f"Task {task_params.id} not found for resubscription")
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
# Setup SSE consumer with resubscribe flag
try:
sse_queue = await self._setup_sse_consumer(
task_params.id, is_resubscribe=True
)
except ValueError:
logger.warning(
f"Task {task_params.id} not available for resubscription"
)
return JSONRPCResponse(
id=request.id,
error=InternalError(
message="Task not available for resubscription"
),
)
# Send initial status update to the new subscriber
status = task_data.get("status", {})
final = status.get("state") in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
]
# Convert to TaskStatus object
task_status = TaskStatus(
state=status.get("state", TaskState.UNKNOWN),
timestamp=datetime.fromisoformat(
status.get("timestamp", datetime.now().isoformat())
),
message=status.get("message"),
)
# Publish to the specific queue
await sse_queue.put(
TaskStatusUpdateEvent(
id=task_params.id, status=task_status, final=final
)
)
# Return generator to dequeue events for SSE
return self._dequeue_events_for_sse(request.id, task_params.id, sse_queue)
except Exception as e:
logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True)
return JSONRPCResponse(
id=request.id,
error=InternalError(message=f"Error resubscribing to task: {str(e)}"),
)
async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]:
"""
Create initial task data structure.
Args:
params: Task send parameters
Returns:
Task data dictionary
"""
# Create task with initial status
task_data = {
"id": params.id,
"sessionId": params.sessionId,
"status": {
"state": TaskState.SUBMITTED,
"timestamp": datetime.now().isoformat(),
"message": None,
"error": None,
},
"artifacts": [],
"history": [params.message.model_dump()],
"metadata": params.metadata or {},
}
# Save task data
await self.redis_cache.set(f"task:{params.id}", task_data)
return task_data
async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None:
"""
Execute a task using the agent adapter.
This function is responsible for the actual execution of the task
by the agent, updating its status as it progresses.
Args:
task: Data of the task to execute
params: Task send parameters
"""
task_id = task["id"]
agent_id = params.agentId
message_text = ""
# Extract the message text
if params.message and params.message.parts:
for part in params.message.parts:
if part.type == "text":
message_text += part.text
if not message_text:
await self._update_task_status(
task_id, TaskState.FAILED, "Message contains no text", final=True
)
return
# Check whether execution is already in progress
task_status = task.get("status", {})
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
logger.info(f"Task {task_id} is already running or completed")
return
try:
# Update to "working" state
await self._update_task_status(
task_id, TaskState.WORKING, "Processing request"
)
# Run the agent
if self.agent_runner:
response = await self.agent_runner.run_agent(
agent_id=agent_id,
message=message_text,
session_id=params.sessionId,
task_id=task_id,
)
# Process the agent's response
if response and isinstance(response, dict):
# Extract text from the response
response_text = response.get("text", "")
if not response_text and "message" in response:
message = response.get("message", {})
parts = message.get("parts", [])
for part in parts:
if part.get("type") == "text":
response_text += part.get("text", "")
# Build the agent's final message
if response_text:
# Create an artifact for the response
artifact = Artifact(
name="response",
parts=[TextPart(text=response_text)],
index=0,
lastChunk=True,
)
# Add the artifact to the task
await self._add_task_artifact(task_id, artifact)
# Update the task status to completed
await self._update_task_status(
task_id, TaskState.COMPLETED, response_text, final=True
)
else:
await self._update_task_status(
task_id,
TaskState.FAILED,
"The agent did not return a valid response",
final=True,
)
else:
await self._update_task_status(
task_id,
TaskState.FAILED,
"Invalid response from the agent",
final=True,
)
else:
await self._update_task_status(
task_id,
TaskState.FAILED,
"Agent adapter not configured",
final=True,
)
except Exception as e:
logger.error(f"Error executing task {task_id}: {str(e)}")
await self._update_task_status(
task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
)
async def _update_task_status(
self, task_id: str, state: TaskState, message_text: str, final: bool = False
) -> None:
"""
Update the status of a task.
Args:
task_id: ID of the task to update
state: New task state
message_text: Text of the message associated with the status
final: Whether this is the task's final status
"""
try:
# Fetch the current task data
task_data = await self.redis_cache.get(f"task:{task_id}")
if not task_data:
logger.warning(
f"Could not update status: task {task_id} not found"
)
return
# Create a status object with the message
agent_message = Message(
role="agent",
parts=[TextPart(text=message_text)],
metadata={"timestamp": datetime.now().isoformat()},
)
status = TaskStatus(
state=state, message=agent_message, timestamp=datetime.now()
)
# Update the status on the task
task_data["status"] = status.model_dump(exclude_none=True)
# Update the history, creating it if missing
if "history" not in task_data:
task_data["history"] = []
# Append the message to the history
task_data["history"].append(agent_message.model_dump(exclude_none=True))
# Store the updated task
await self.redis_cache.set(f"task:{task_id}", task_data)
# Create a status update event
status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final)
# Publish the update
await self._publish_task_update(task_id, status_event)
# Send a push notification, if configured
if final or state in [
TaskState.FAILED,
TaskState.COMPLETED,
TaskState.CANCELED,
]:
await self._send_push_notification_for_task(
task_id=task_id, state=state, message_text=message_text
)
except Exception as e:
logger.error(f"Error updating status of task {task_id}: {str(e)}")
async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None:
"""
Add an artifact to a task and publish the update.
Args:
task_id: Task ID
artifact: Artifact to add
"""
logger.info(f"Adding artifact to task {task_id}")
# Update task data
task_data = await self.redis_cache.get(f"task:{task_id}")
if task_data:
if "artifacts" not in task_data:
task_data["artifacts"] = []
# Convert artifact to dict
artifact_dict = artifact.model_dump()
task_data["artifacts"].append(artifact_dict)
await self.redis_cache.set(f"task:{task_id}", task_data)
# Create artifact update event
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact)
# Publish event
await self._publish_task_update(task_id, event)
async def _publish_task_update(
self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]
) -> None:
"""
Publish a task update event to all subscribers.
Args:
task_id: Task ID
event: Event to publish
"""
async with self.subscriber_lock:
if task_id not in self.task_sse_subscribers:
return
subscribers = self.task_sse_subscribers[task_id]
for subscriber in subscribers:
try:
await subscriber.put(event)
except Exception as e:
logger.error(f"Error publishing event to subscriber: {str(e)}")
async def _send_push_notification_for_task(
self,
task_id: str,
state: str,
message_text: str = None,
system_message: str = None,
) -> None:
"""
Send push notification for a task if configured.
Args:
task_id: Task ID
state: Task state
message_text: Optional message text
system_message: Optional system message
"""
if not self.push_notification_service:
return
try:
# Get push notification config
config = await self.redis_cache.get(f"task:{task_id}:push")
if not config:
return
# Create message if provided
message = None
if message_text:
message = {
"role": "agent",
"parts": [{"type": "text", "text": message_text}],
}
elif system_message:
# We use 'agent' instead of 'system' since Message only accepts 'user' or 'agent'
message = {
"role": "agent",
"parts": [{"type": "text", "text": system_message}],
}
# Send notification
await self.push_notification_service.send_notification(
url=config["url"],
task_id=task_id,
state=state,
message=message,
headers=config.get("headers", {}),
)
except Exception as e:
logger.error(
f"Error sending push notification for task {task_id}: {str(e)}"
)
async def _setup_sse_consumer(
self, task_id: str, is_resubscribe: bool = False
) -> asyncio.Queue:
"""
Set up an SSE consumer queue for a task.
Args:
task_id: Task ID
is_resubscribe: Whether this is a resubscription
Returns:
Queue for events
Raises:
ValueError: If resubscribing to non-existent task
"""
async with self.subscriber_lock:
if task_id not in self.task_sse_subscribers:
if is_resubscribe:
raise ValueError("Task not found for resubscription")
self.task_sse_subscribers[task_id] = []
queue = asyncio.Queue()
self.task_sse_subscribers[task_id].append(queue)
return queue
async def _dequeue_events_for_sse(
self, request_id: str, task_id: str, event_queue: asyncio.Queue
) -> AsyncIterable[SendTaskStreamingResponse]:
"""
Dequeue and yield events for SSE streaming.
Args:
request_id: Request ID
task_id: Task ID
event_queue: Queue for events
Yields:
SSE events wrapped in SendTaskStreamingResponse
"""
try:
while True:
event = await event_queue.get()
if isinstance(event, JSONRPCError):
yield SendTaskStreamingResponse(id=request_id, error=event)
break
yield SendTaskStreamingResponse(id=request_id, result=event)
# Check if this is the final event
is_final = False
if hasattr(event, "final") and event.final:
is_final = True
if is_final:
break
finally:
# Clean up the subscription when done
async with self.subscriber_lock:
if task_id in self.task_sse_subscribers:
try:
self.task_sse_subscribers[task_id].remove(event_queue)
# Remove the task from the dict if no more subscribers
if not self.task_sse_subscribers[task_id]:
del self.task_sse_subscribers[task_id]
except ValueError:
pass # Queue might have been removed already
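The `historyLength` handling in `on_get_task` is self-contained enough to sketch on its own. `truncate_history` is a hypothetical standalone version of that logic: `0` clears the history, `N` keeps the last `N` entries, and `None` leaves it untouched:

```python
def truncate_history(history, history_length):
    # Same rules as on_get_task's historyLength handling.
    if history_length is None or not history:
        return history
    if history_length == 0:
        return []
    if len(history) > history_length:
        return history[-history_length:]
    return history

msgs = ["m1", "m2", "m3", "m4"]
print(truncate_history(msgs, 2))     # ['m3', 'm4']
print(truncate_history(msgs, 0))     # []
print(truncate_history(msgs, None))  # ['m1', 'm2', 'm3', 'm4']
```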


@ -3,7 +3,7 @@ from sqlalchemy.exc import SQLAlchemyError
from fastapi import HTTPException, status
from src.models.models import Agent
from src.schemas.schemas import AgentCreate
from typing import List, Optional, Dict, Any, Union
from src.services.mcp_server_service import get_mcp_server
import uuid
import logging
@ -20,9 +20,17 @@ def validate_sub_agents(db: Session, sub_agents: List[uuid.UUID]) -> bool:
return True
def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]:
"""Search for an agent by ID"""
try:
# Convert to UUID if it's a string
if isinstance(agent_id, str):
try:
agent_id = uuid.UUID(agent_id)
except ValueError:
logger.warning(f"Invalid agent ID: {agent_id}")
return None
agent = db.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent not found: {agent_id}")


@ -0,0 +1,265 @@
"""
Push Notification Authentication Service.
This service implements JWT authentication for A2A push notifications,
allowing agents to send authenticated notifications and clients to verify
the authenticity of received notifications.
"""
from jwcrypto import jwk
import uuid
import time
import json
import hashlib
import httpx
import logging
import jwt
from jwt import PyJWK, PyJWKClient
from fastapi import Request
from starlette.responses import JSONResponse
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
AUTH_HEADER_PREFIX = "Bearer "
class PushNotificationAuth:
"""
Base class for push notification authentication.
Contains common methods used by both the sender and the receiver
of push notifications.
"""
def _calculate_request_body_sha256(self, data: Dict[str, Any]) -> str:
"""
Calculates the SHA256 hash of the request body.
This logic needs to be the same for the agent that signs the payload
and for the client that verifies it.
Args:
data: Request body data
Returns:
SHA256 hash as a hexadecimal string
"""
body_str = json.dumps(
data,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
)
return hashlib.sha256(body_str.encode()).hexdigest()
class PushNotificationSenderAuth(PushNotificationAuth):
"""
Authentication for the push notification sender.
This class is used by the A2A server to authenticate notifications
sent to callback URLs of clients.
"""
def __init__(self):
"""
Initialize the push notification sender authentication service.
"""
self.public_keys = []
self.private_key_jwk = None
@staticmethod
async def verify_push_notification_url(url: str) -> bool:
"""
Verifies if a push notification URL is valid and responds correctly.
Sends a validation token and verifies if the response contains the same token.
Args:
url: URL to be verified
Returns:
True if the URL is verified successfully, False otherwise
"""
async with httpx.AsyncClient(timeout=10) as client:
try:
validation_token = str(uuid.uuid4())
response = await client.get(
url, params={"validationToken": validation_token}
)
response.raise_for_status()
is_verified = response.text == validation_token
logger.info(f"Push notification URL verified: {url} => {is_verified}")
return is_verified
except Exception as e:
logger.warning(f"Error verifying push notification URL {url}: {e}")
return False
def generate_jwk(self):
"""
Generates a new JWK pair for signing.
The key pair is used to sign push notifications.
The public key is available via the JWKS endpoint.
"""
key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
self.public_keys.append(key.export_public(as_dict=True))
self.private_key_jwk = PyJWK.from_json(key.export_private())
def handle_jwks_endpoint(self, _request: Request) -> JSONResponse:
"""
Handles the JWKS endpoint to allow clients to obtain the public keys.
Args:
_request: HTTP request
Returns:
JSON response with the public keys
"""
return JSONResponse({"keys": self.public_keys})
def _generate_jwt(self, data: Dict[str, Any]) -> str:
"""
Generates a JWT token by signing the SHA256 hash of the payload and the timestamp.
The payload is signed with the private key to ensure integrity.
The timestamp (iat) prevents replay attacks.
Args:
data: Payload data
Returns:
Signed JWT token
"""
iat = int(time.time())
return jwt.encode(
{
"iat": iat,
"request_body_sha256": self._calculate_request_body_sha256(data),
},
key=self.private_key_jwk.key,
headers={"kid": self.private_key_jwk.key_id},
algorithm="RS256",
)
async def send_push_notification(self, url: str, data: Dict[str, Any]) -> bool:
"""
Sends an authenticated push notification to the specified URL.
Args:
url: URL to send the notification
data: Notification data
Returns:
True if the notification was sent successfully, False otherwise
"""
if not self.private_key_jwk:
logger.error(
"Attempt to send push notification without generating JWK keys"
)
return False
try:
jwt_token = self._generate_jwt(data)
headers = {"Authorization": f"Bearer {jwt_token}"}
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Push notification sent to URL: {url}")
return True
except Exception as e:
logger.warning(f"Error sending push notification to URL {url}: {e}")
return False
class PushNotificationReceiverAuth(PushNotificationAuth):
"""
Authentication for the push notification receiver.
This class is used by clients to verify the authenticity
of notifications received from the A2A server.
"""
def __init__(self):
"""
Initialize the push notification receiver authentication service.
"""
self.jwks_client = None
async def load_jwks(self, jwks_url: str):
"""
Loads the public JWKS keys from a URL.
Args:
jwks_url: URL of the JWKS endpoint
"""
self.jwks_client = PyJWKClient(jwks_url)
async def verify_push_notification(self, request: Request) -> bool:
"""
Verifies the authenticity of a push notification.
Args:
request: HTTP request containing the notification
Returns:
True if the notification is authentic, False otherwise
Raises:
ValueError: If the token is invalid or expired
"""
if not self.jwks_client:
logger.error("Attempt to verify notification without loading JWKS keys")
return False
# Verify authentication header
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
logger.warning("Invalid authorization header")
return False
try:
# Extract and verify token
token = auth_header[len(AUTH_HEADER_PREFIX) :]
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
# Decode token
decode_token = jwt.decode(
token,
signing_key.key,
options={"require": ["iat", "request_body_sha256"]},
algorithms=["RS256"],
)
# Verify request body integrity
body_data = await request.json()
actual_body_sha256 = self._calculate_request_body_sha256(body_data)
if actual_body_sha256 != decode_token["request_body_sha256"]:
# The payload signature does not match the hash in the signed token
logger.warning("Request body hash does not match the token")
raise ValueError("Invalid request body")
# Verify token age (maximum 5 minutes)
if time.time() - decode_token["iat"] > 60 * 5:
# Do not allow notifications older than 5 minutes
# This prevents replay attacks
logger.warning("Token expired")
raise ValueError("Token expired")
return True
except Exception as e:
logger.error(f"Error verifying push notification: {e}")
return False
# Global instance of the push notification sender authentication service
push_notification_auth = PushNotificationSenderAuth()
# Generate JWK keys on initialization
push_notification_auth.generate_jwk()
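The sender signs each notification over a digest of the exact request body, and the receiver recomputes that digest before trusting the payload. A minimal sketch of the hashing step follows; the real `_calculate_request_body_sha256` lives in the shared base class (not shown in this hunk), and canonical JSON serialization is an assumption here:

```python
import hashlib
import json


def calculate_request_body_sha256(data: dict) -> str:
    # Canonical serialization (sorted keys, no whitespace) so sender and
    # receiver hash identical bytes for the same payload
    body = json.dumps(data, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(body.encode("utf-8")).hexdigest()


# Key order must not change the digest
a = calculate_request_body_sha256({"taskId": "t1", "state": "completed"})
b = calculate_request_body_sha256({"state": "completed", "taskId": "t1"})
assert a == b
assert len(a) == 64  # hex-encoded SHA-256
```

Because keys are sorted before hashing, semantically equal payloads with different key order produce the same digest on both ends, which is what lets the receiver's `request_body_sha256` check work reliably.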

@ -4,6 +4,8 @@ from datetime import datetime
from typing import Dict, Any, Optional
import asyncio
from src.services.push_notification_auth_service import push_notification_auth
logger = logging.getLogger(__name__)
@ -20,10 +22,24 @@ class PushNotificationService:
headers: Optional[Dict[str, str]] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
use_jwt_auth: bool = True,
) -> bool:
"""
Send a push notification to the specified URL.
Implements exponential backoff retry.
Args:
url: URL to send the notification
task_id: Task ID
state: Task state
message: Optional message
headers: Optional HTTP headers
max_retries: Maximum number of retries
retry_delay: Initial delay between retries
use_jwt_auth: Whether to use JWT authentication
Returns:
True if the notification was sent successfully, False otherwise
"""
payload = {
"taskId": task_id,
@ -32,18 +48,37 @@ class PushNotificationService:
"message": message,
}
# First URL verification
if use_jwt_auth:
is_url_valid = await push_notification_auth.verify_push_notification_url(
url
)
if not is_url_valid:
logger.warning(f"Invalid push notification URL: {url}")
return False
for attempt in range(max_retries):
    try:
        if use_jwt_auth:
            # Use JWT authentication
            success = await push_notification_auth.send_push_notification(
                url, payload
            )
            if success:
                return True
            else:
                logger.warning(
                    f"Push notification failed. "
                    f"Attempt {attempt + 1}/{max_retries}"
                )
        else:
            # Legacy method without JWT authentication
            async with self.session.post(
                url, json=payload, headers=headers or {}, timeout=10
            ) as response:
                if response.status in (200, 201, 202, 204):
                    logger.info(f"Push notification sent to URL: {url}")
                    return True
                else:
                    logger.warning(
                        f"Failed to send push notification with status {response.status}. "
                        f"Attempt {attempt + 1}/{max_retries}"
                    )
except Exception as e:
logger.error(
f"Error sending push notification: {str(e)}. "
@ -56,9 +91,9 @@ class PushNotificationService:
return False
async def close(self):
"""Close the HTTP session"""
await self.session.close()
# Global instance of the push notification service
push_notification_service = PushNotificationService()

@ -0,0 +1,556 @@
"""
Cache Redis service for the A2A protocol.
This service provides an interface for storing and retrieving data related to tasks,
push notification configurations, and other A2A-related data.
"""
import json
import logging
from typing import Any, Dict, List, Optional, Union
import asyncio
import redis.asyncio as aioredis
from redis.exceptions import RedisError
from src.config.redis import get_redis_config, get_a2a_config
import threading
import time
logger = logging.getLogger(__name__)
class _InMemoryCacheFallback:
"""
Fallback in-memory cache implementation for when Redis is not available.
This should only be used for development or testing environments.
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
"""Singleton implementation."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Initialize cache storage."""
if not getattr(self, "_initialized", False):
with self._lock:
if not getattr(self, "_initialized", False):
self._data = {}
self._ttls = {}
self._hash_data = {}
self._list_data = {}
self._data_lock = threading.Lock()
self._initialized = True
logger.warning(
"Initializing in-memory cache fallback (not for production)"
)
async def set(self, key, value, ex=None):
"""Set a key with optional expiration."""
with self._data_lock:
self._data[key] = value
if ex is not None:
self._ttls[key] = time.time() + ex
elif key in self._ttls:
del self._ttls[key]
return True
async def setex(self, key, ex, value):
"""Set a key with expiration."""
return await self.set(key, value, ex)
async def get(self, key):
"""Get a key value."""
with self._data_lock:
# Check if expired
if key in self._ttls and time.time() > self._ttls[key]:
del self._data[key]
del self._ttls[key]
return None
return self._data.get(key)
async def delete(self, key):
"""Delete a key."""
with self._data_lock:
if key in self._data:
del self._data[key]
if key in self._ttls:
del self._ttls[key]
return 1
return 0
async def exists(self, key):
"""Check if key exists."""
with self._data_lock:
if key in self._ttls and time.time() > self._ttls[key]:
del self._data[key]
del self._ttls[key]
return 0
return 1 if key in self._data else 0
async def hset(self, key, field, value):
"""Set a hash field."""
with self._data_lock:
if key not in self._hash_data:
self._hash_data[key] = {}
self._hash_data[key][field] = value
return 1
async def hget(self, key, field):
"""Get a hash field."""
with self._data_lock:
if key not in self._hash_data:
return None
return self._hash_data[key].get(field)
async def hdel(self, key, field):
"""Delete a hash field."""
with self._data_lock:
if key in self._hash_data and field in self._hash_data[key]:
del self._hash_data[key][field]
return 1
return 0
async def hgetall(self, key):
"""Get all hash fields."""
with self._data_lock:
if key not in self._hash_data:
return {}
return dict(self._hash_data[key])
async def rpush(self, key, value):
"""Add to a list."""
with self._data_lock:
if key not in self._list_data:
self._list_data[key] = []
self._list_data[key].append(value)
return len(self._list_data[key])
async def lrange(self, key, start, end):
"""Get range from list."""
with self._data_lock:
if key not in self._list_data:
return []
# Handle negative indices
if end < 0:
end = len(self._list_data[key]) + end + 1
return self._list_data[key][start:end]
async def expire(self, key, seconds):
"""Set expiration on key."""
with self._data_lock:
if key in self._data:
self._ttls[key] = time.time() + seconds
return 1
return 0
async def flushdb(self):
"""Clear all data."""
with self._data_lock:
self._data.clear()
self._ttls.clear()
self._hash_data.clear()
self._list_data.clear()
return True
async def keys(self, pattern="*"):
"""Get keys matching pattern."""
with self._data_lock:
# Clean expired keys
now = time.time()
expired_keys = [k for k, exp in self._ttls.items() if now > exp]
for k in expired_keys:
if k in self._data:
del self._data[k]
del self._ttls[k]
# Simple pattern matching
result = []
if pattern == "*":
result = list(self._data.keys())
elif pattern.endswith("*"):
prefix = pattern[:-1]
result = [k for k in self._data.keys() if k.startswith(prefix)]
elif pattern.startswith("*"):
suffix = pattern[1:]
result = [k for k in self._data.keys() if k.endswith(suffix)]
else:
if pattern in self._data:
result = [pattern]
return result
async def ping(self):
"""Test connection."""
return True
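The `keys` fallback above only understands a leading or trailing `*`. If fuller glob semantics were ever needed, the standard library's `fnmatch` already covers the general case; a hypothetical drop-in for the pattern-matching step:

```python
import fnmatch

# Glob matching over cached keys, as a stand-in for the hand-rolled
# prefix/suffix logic in _InMemoryCacheFallback.keys()
keys = ["a2a:task:1", "a2a:push:1", "session:9"]

assert fnmatch.filter(keys, "a2a:*") == ["a2a:task:1", "a2a:push:1"]
assert fnmatch.filter(keys, "*:1") == ["a2a:task:1", "a2a:push:1"]
# Interior wildcards work too, which the prefix/suffix version cannot handle
assert fnmatch.filter(keys, "a2a:*:1") == ["a2a:task:1", "a2a:push:1"]
```

Note that `fnmatch` is not byte-for-byte identical to Redis `KEYS` glob semantics (e.g. character classes differ slightly), so this is only adequate for the simple prefixed keys used here.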
class RedisCacheService:
"""
Cache service using Redis for storing A2A-related data.
This implementation uses a real Redis connection for distributed caching
and data persistence across multiple instances.
If Redis is not available, falls back to an in-memory implementation.
"""
def __init__(self, redis_url: Optional[str] = None):
"""
Initialize the Redis cache service.
Args:
redis_url: Redis server URL (optional, defaults to config value)
"""
if redis_url:
self._redis_url = redis_url
else:
# Build the URL from the configuration components
config = get_redis_config()
protocol = "rediss" if config.get("ssl", False) else "redis"
auth = f":{config['password']}@" if config.get("password") else ""
self._redis_url = (
f"{protocol}://{auth}{config['host']}:{config['port']}/{config['db']}"
)
self._redis = None
self._in_memory_mode = False
self._connecting = False
self._connection_lock = asyncio.Lock()
logger.info(f"Initializing RedisCacheService with URL: {self._redis_url}")
async def _get_redis(self):
"""
Get a Redis connection, creating it if necessary.
Falls back to in-memory implementation if Redis is not available.
Returns:
Redis connection or in-memory fallback
"""
if self._redis is not None:
return self._redis
async with self._connection_lock:
if self._redis is None and not self._connecting:
try:
self._connecting = True
logger.info(f"Connecting to Redis at {self._redis_url}")
self._redis = aioredis.from_url(
self._redis_url, encoding="utf-8", decode_responses=True
)
# Connection test
await self._redis.ping()
logger.info("Redis connection successful")
except Exception as e:
logger.error(f"Error connecting to Redis: {str(e)}")
logger.warning(
"Falling back to in-memory cache (not suitable for production)"
)
self._redis = _InMemoryCacheFallback()
self._in_memory_mode = True
finally:
self._connecting = False
return self._redis
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""
Store a value in the cache.
Args:
key: Key for the value
value: Value to store
ttl: Time to live in seconds (optional)
"""
try:
redis = await self._get_redis()
# Convert dict/list to JSON string
if isinstance(value, (dict, list)):
value = json.dumps(value)
if ttl:
await redis.setex(key, ttl, value)
else:
await redis.set(key, value)
logger.debug(f"Set cache key: {key}")
except Exception as e:
logger.error(f"Error setting Redis key {key}: {str(e)}")
async def get(self, key: str) -> Optional[Any]:
"""
Retrieve a value from the cache.
Args:
key: Key for the value to retrieve
Returns:
The stored value or None if not found
"""
try:
redis = await self._get_redis()
value = await redis.get(key)
if value is None:
return None
try:
# Try to parse as JSON
return json.loads(value)
except json.JSONDecodeError:
# Return as-is if not JSON
return value
except Exception as e:
logger.error(f"Error getting Redis key {key}: {str(e)}")
return None
async def delete(self, key: str) -> bool:
"""
Remove a value from the cache.
Args:
key: Key for the value to remove
Returns:
True if the value was removed, False if it didn't exist
"""
try:
redis = await self._get_redis()
result = await redis.delete(key)
return result > 0
except Exception as e:
logger.error(f"Error deleting Redis key {key}: {str(e)}")
return False
async def exists(self, key: str) -> bool:
"""
Check if a key exists in the cache.
Args:
key: Key to check
Returns:
True if the key exists, False otherwise
"""
try:
redis = await self._get_redis()
return await redis.exists(key) > 0
except Exception as e:
logger.error(f"Error checking Redis key {key}: {str(e)}")
return False
async def set_hash(self, key: str, field: str, value: Any) -> None:
"""
Store a value in a hash.
Args:
key: Hash key
field: Hash field
value: Value to store
"""
try:
redis = await self._get_redis()
# Convert dict/list to JSON string
if isinstance(value, (dict, list)):
value = json.dumps(value)
await redis.hset(key, field, value)
logger.debug(f"Set hash field: {key}:{field}")
except Exception as e:
logger.error(f"Error setting Redis hash {key}:{field}: {str(e)}")
async def get_hash(self, key: str, field: str) -> Optional[Any]:
"""
Retrieve a value from a hash.
Args:
key: Hash key
field: Hash field
Returns:
The stored value or None if not found
"""
try:
redis = await self._get_redis()
value = await redis.hget(key, field)
if value is None:
return None
try:
# Try to parse as JSON
return json.loads(value)
except json.JSONDecodeError:
# Return as-is if not JSON
return value
except Exception as e:
logger.error(f"Error getting Redis hash {key}:{field}: {str(e)}")
return None
async def delete_hash(self, key: str, field: str) -> bool:
"""
Remove a value from a hash.
Args:
key: Hash key
field: Hash field
Returns:
True if the value was removed, False if it didn't exist
"""
try:
redis = await self._get_redis()
result = await redis.hdel(key, field)
return result > 0
except Exception as e:
logger.error(f"Error deleting Redis hash {key}:{field}: {str(e)}")
return False
async def get_all_hash(self, key: str) -> Dict[str, Any]:
"""
Retrieve all values from a hash.
Args:
key: Hash key
Returns:
Dictionary with all hash values
"""
try:
redis = await self._get_redis()
result_dict = await redis.hgetall(key)
if not result_dict:
return {}
# Try to parse each value as JSON
parsed_dict = {}
for field, value in result_dict.items():
try:
parsed_dict[field] = json.loads(value)
except json.JSONDecodeError:
parsed_dict[field] = value
return parsed_dict
except Exception as e:
logger.error(f"Error getting all Redis hash fields for {key}: {str(e)}")
return {}
async def push_list(self, key: str, value: Any) -> int:
"""
Add a value to the end of a list.
Args:
key: List key
value: Value to add
Returns:
Size of the list after the addition
"""
try:
redis = await self._get_redis()
# Convert dict/list to JSON string
if isinstance(value, (dict, list)):
value = json.dumps(value)
return await redis.rpush(key, value)
except Exception as e:
logger.error(f"Error pushing to Redis list {key}: {str(e)}")
return 0
async def get_list(self, key: str, start: int = 0, end: int = -1) -> List[Any]:
"""
Retrieve values from a list.
Args:
key: List key
start: Initial index (inclusive)
end: Final index (inclusive), -1 for all
Returns:
List with the retrieved values
"""
try:
redis = await self._get_redis()
values = await redis.lrange(key, start, end)
if not values:
return []
# Try to parse each value as JSON
result = []
for value in values:
try:
result.append(json.loads(value))
except json.JSONDecodeError:
result.append(value)
return result
except Exception as e:
logger.error(f"Error getting Redis list {key}: {str(e)}")
return []
async def expire(self, key: str, ttl: int) -> bool:
"""
Set a time-to-live for a key.
Args:
key: Key
ttl: Time-to-live in seconds
Returns:
True if the key exists and the TTL was set, False otherwise
"""
try:
redis = await self._get_redis()
return await redis.expire(key, ttl)
except Exception as e:
logger.error(f"Error setting expire for Redis key {key}: {str(e)}")
return False
async def clear(self) -> None:
"""
Clear the entire cache.
Warning: This is a destructive operation and will remove all data.
Only use in development/test environments.
"""
try:
redis = await self._get_redis()
await redis.flushdb()
logger.warning("Redis database flushed - all data cleared")
except Exception as e:
logger.error(f"Error clearing Redis database: {str(e)}")
async def keys(self, pattern: str = "*") -> List[str]:
"""
Retrieve keys that match a pattern.
Args:
pattern: Glob pattern to filter keys
Returns:
List of keys that match the pattern
"""
try:
redis = await self._get_redis()
return await redis.keys(pattern)
except Exception as e:
logger.error(f"Error getting Redis keys with pattern {pattern}: {str(e)}")
return []
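A key behavior of `RedisCacheService` is that dicts and lists are transparently serialized to JSON on `set` and parsed back on `get`. A minimal stand-in (`JsonCache` is a hypothetical sketch, not the real class) showing the round trip and one caveat of the scheme:

```python
import asyncio
import json


class JsonCache:
    """Hypothetical stand-in mimicking RedisCacheService's JSON handling."""

    def __init__(self):
        self._data = {}

    async def set(self, key, value, ttl=None):
        if isinstance(value, (dict, list)):
            value = json.dumps(value)  # dicts/lists stored as JSON strings
        self._data[key] = value

    async def get(self, key):
        value = self._data.get(key)
        if value is None:
            return None
        try:
            return json.loads(value)  # parsed back transparently
        except (json.JSONDecodeError, TypeError):
            return value


async def demo():
    cache = JsonCache()
    await cache.set("a2a:task:42", {"state": "working", "retries": 0})
    round_tripped = await cache.get("a2a:task:42")
    # Caveat: plain strings that happen to parse as JSON are coerced too
    await cache.set("count", "123")
    coerced = await cache.get("count")
    return round_tripped, coerced


task, count = asyncio.run(demo())
assert task == {"state": "working", "retries": 0}
assert count == 123  # stored "123" comes back as an int
```

The coercion caveat also applies to the real service: a string value of `"123"` or `"true"` will come back typed, which callers should keep in mind when storing raw strings.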

@ -3,12 +3,12 @@ import json
from typing import AsyncGenerator, Dict, Any
from fastapi import HTTPException
from datetime import datetime
from src.schemas.streaming import (
JSONRPCRequest,
TaskStatusUpdateEvent,
)
from src.services.agent_runner import run_agent
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
@ -25,31 +25,33 @@ class StreamingService:
agent_id: str,
api_key: str,
message: str,
contact_id: str = None,
session_id: str = None,
db: Session = None,
) -> AsyncGenerator[str, None]:
"""
Starts the SSE event streaming for a task.
Args:
agent_id: Agent ID
api_key: API key for authentication
message: Initial message
contact_id: Contact ID (optional)
session_id: Session ID (optional)
db: Database session
Yields:
Formatted SSE events
"""
# Basic API key validation
if not api_key:
raise HTTPException(status_code=401, detail="API key is required")
# Generate unique IDs
task_id = contact_id or str(uuid.uuid4())
request_id = str(uuid.uuid4())
# Build JSON-RPC payload
payload = JSONRPCRequest(
id=request_id,
params={
@ -62,7 +64,7 @@ class StreamingService:
},
)
# Register connection
self.active_connections[task_id] = {
"agent_id": agent_id,
"api_key": api_key,
@ -70,7 +72,7 @@ class StreamingService:
}
try:
# Send start event
yield self._format_sse_event(
"status",
TaskStatusUpdateEvent(
@ -80,10 +82,10 @@ class StreamingService:
).model_dump_json(),
)
# Execute the agent
result = await run_agent(
str(agent_id),
contact_id or task_id,
message,
session_service,
artifacts_service,
@ -92,7 +94,7 @@ class StreamingService:
session_id,
)
# Send the agent's response as a separate event
yield self._format_sse_event(
"message",
json.dumps(
@ -104,7 +106,7 @@ class StreamingService:
),
)
# Completion event
yield self._format_sse_event(
"status",
TaskStatusUpdateEvent(
@ -114,7 +116,7 @@ class StreamingService:
)
except Exception as e:
# Error event
yield self._format_sse_event(
"status",
TaskStatusUpdateEvent(
@ -126,14 +128,14 @@ class StreamingService:
raise
finally:
# Clean up the connection
self.active_connections.pop(task_id, None)
def _format_sse_event(self, event_type: str, data: str) -> str:
"""Format an SSE event."""
return f"event: {event_type}\ndata: {data}\n\n"
async def close_connection(self, task_id: str):
"""Close a streaming connection."""
if task_id in self.active_connections:
self.active_connections.pop(task_id)
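The `_format_sse_event` helper emits the standard Server-Sent Events wire format: an `event:` line naming the event type, a `data:` line with the payload, and a blank line terminating the event. A standalone check of the framing (the function body is copied from the service above):

```python
def format_sse_event(event_type: str, data: str) -> str:
    # SSE wire format: "event:" field, "data:" field, blank-line terminator
    return f"event: {event_type}\ndata: {data}\n\n"


evt = format_sse_event("status", '{"state": "working"}')
assert evt == 'event: status\ndata: {"state": "working"}\n\n'
```

The trailing blank line is what tells an `EventSource` client that the event is complete, so it must not be dropped when concatenating events into the response stream.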

src/utils/a2a_utils.py Normal file

@ -0,0 +1,110 @@
"""
A2A protocol utilities.
This module contains utility functions for the A2A protocol implementation.
"""
import logging
from typing import List, Optional, Any, Dict
from src.schemas.a2a import (
ContentTypeNotSupportedError,
UnsupportedOperationError,
JSONRPCResponse,
)
logger = logging.getLogger(__name__)
def are_modalities_compatible(
server_output_modes: Optional[List[str]], client_output_modes: Optional[List[str]]
) -> bool:
"""
Check if server and client modalities are compatible.
Modalities are compatible if they are both non-empty
and there is at least one common element.
Args:
server_output_modes: List of output modes supported by the server
client_output_modes: List of output modes requested by the client
Returns:
True if compatible, False otherwise
"""
# If client doesn't specify modes, assume all are accepted
if client_output_modes is None or len(client_output_modes) == 0:
return True
# If server doesn't specify modes, assume all are supported
if server_output_modes is None or len(server_output_modes) == 0:
return True
# Check if there's at least one common mode
return any(mode in server_output_modes for mode in client_output_modes)
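The compatibility rule reduces to a small truth table: an empty or absent list on either side means "anything goes", otherwise at least one mode must overlap. Checked against a standalone copy of the function:

```python
def are_modalities_compatible(server_output_modes, client_output_modes):
    # Mirrors the utility above: None/empty on either side means "accept all"
    if not client_output_modes:
        return True
    if not server_output_modes:
        return True
    return any(mode in server_output_modes for mode in client_output_modes)


assert are_modalities_compatible(None, None)
assert are_modalities_compatible(["text"], None)
assert are_modalities_compatible(None, ["text"])
assert are_modalities_compatible(["text", "file"], ["file", "audio"])
assert not are_modalities_compatible(["text"], ["audio"])
```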
def new_incompatible_types_error(request_id: str) -> JSONRPCResponse:
"""
Create a JSON-RPC response for incompatible content types error.
Args:
request_id: The ID of the request that caused the error
Returns:
JSON-RPC response with ContentTypeNotSupportedError
"""
return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())
def new_not_implemented_error(request_id: str) -> JSONRPCResponse:
"""
Create a JSON-RPC response for unimplemented operation error.
Args:
request_id: The ID of the request that caused the error
Returns:
JSON-RPC response with UnsupportedOperationError
"""
return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str:
"""
Create a unique task ID for an agent and session.
Args:
agent_id: The ID of the agent
session_id: The ID of the session
timestamp: Optional timestamp to include in the ID
Returns:
A unique task ID
"""
import uuid
import time
timestamp = timestamp or str(int(time.time()))
unique = uuid.uuid4().hex[:8]
return f"{agent_id}_{session_id}_{timestamp}_{unique}"
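The resulting ID has four underscore-separated parts: agent, session, Unix timestamp, and eight hex characters of entropy. A quick check of the format (function body copied from above; the example agent/session names are arbitrary):

```python
import re
import time
import uuid


def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str:
    # agent, session, unix timestamp, 8 hex chars of entropy
    timestamp = timestamp or str(int(time.time()))
    unique = uuid.uuid4().hex[:8]
    return f"{agent_id}_{session_id}_{timestamp}_{unique}"


tid = create_task_id("agent1", "sess9", timestamp="1700000000")
assert re.fullmatch(r"agent1_sess9_1700000000_[0-9a-f]{8}", tid)
```

One caveat worth noting: because parts are joined with `_`, an `agent_id` or `session_id` that itself contains underscores makes the ID ambiguous to split back apart.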
def format_error_response(error: Exception, request_id: str = None) -> Dict[str, Any]:
"""
Format an exception as a JSON-RPC error response.
Args:
error: The exception to format
request_id: The ID of the request that caused the error
Returns:
JSON-RPC error response as dictionary
"""
from src.schemas.a2a import InternalError, JSONRPCResponse
error_response = JSONRPCResponse(
id=request_id, error=InternalError(message=str(error))
)
return error_response.model_dump(exclude_none=True)