feat(a2a): add Redis configuration and A2A service implementation
This commit is contained in:
parent
910979fb83
commit
4901be8e4c
3
.env
3
.env
@ -18,6 +18,9 @@ REDIS_HOST="localhost"
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=8
|
||||
REDIS_PASSWORD=""
|
||||
REDIS_SSL=false
|
||||
REDIS_KEY_PREFIX="a2a:"
|
||||
REDIS_TTL=3600
|
||||
|
||||
# Tools cache TTL in seconds (1 hour)
|
||||
TOOLS_CACHE_TTL=3600
|
||||
|
@ -18,6 +18,9 @@ REDIS_HOST="localhost"
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
REDIS_PASSWORD="your-redis-password"
|
||||
REDIS_SSL=false
|
||||
REDIS_KEY_PREFIX="a2a:"
|
||||
REDIS_TTL=3600
|
||||
|
||||
# Tools cache TTL in seconds (1 hour)
|
||||
TOOLS_CACHE_TTL=3600
|
||||
@ -44,3 +47,9 @@ ADMIN_INITIAL_PASSWORD="strongpassword123"
|
||||
DEMO_EMAIL="demo@example.com"
|
||||
DEMO_PASSWORD="demo123"
|
||||
DEMO_CLIENT_NAME="Demo Client"
|
||||
|
||||
# A2A settings
|
||||
A2A_TASK_TTL=3600
|
||||
A2A_HISTORY_TTL=86400
|
||||
A2A_PUSH_NOTIFICATION_TTL=3600
|
||||
A2A_SSE_CLIENT_TTL=300
|
||||
|
252
a2a_checklist.md
252
a2a_checklist.md
@ -1,252 +0,0 @@
|
||||
# Checklist de Implementação do Protocolo A2A com Redis
|
||||
|
||||
## 1. Configuração Inicial
|
||||
|
||||
- [ ] **Configurar dependências no arquivo pyproject.toml**
|
||||
|
||||
- Adicionar Redis e dependências relacionadas:
|
||||
```
|
||||
redis = "^5.3.0"
|
||||
sse-starlette = "^2.3.3"
|
||||
jwcrypto = "^1.5.6"
|
||||
pyjwt = {extras = ["crypto"], version = "^2.10.1"}
|
||||
```
|
||||
|
||||
- [ ] **Configurar variáveis de ambiente para Redis**
|
||||
|
||||
- Adicionar em `.env.example` e `.env`:
|
||||
```
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
REDIS_SSL=false
|
||||
REDIS_KEY_PREFIX=a2a:
|
||||
REDIS_TTL=3600
|
||||
```
|
||||
|
||||
- [ ] **Configurar Redis no docker-compose.yml**
|
||||
- Adicionar serviço Redis com portas e volumes apropriados
|
||||
- Configurar segurança básica (senha, se necessário)
|
||||
|
||||
## 2. Implementação de Modelos e Schemas
|
||||
|
||||
- [ ] **Criar schemas A2A em `src/schemas/a2a.py`**
|
||||
|
||||
- Implementar tipos conforme `docs/A2A/samples/python/common/types.py`:
|
||||
- Enums (TaskState, etc.)
|
||||
- Classes de mensagens (TextPart, FilePart, etc.)
|
||||
- Classes de tarefas (Task, TaskStatus, etc.)
|
||||
- Estruturas JSON-RPC
|
||||
- Tipos de erros
|
||||
|
||||
- [ ] **Implementar validadores de modelo**
|
||||
- Validadores para conteúdos de arquivo
|
||||
- Validadores para formatos de mensagem
|
||||
- Conversores de formato para compatibilidade com o protocolo
|
||||
|
||||
## 3. Implementação do Cache Redis
|
||||
|
||||
- [ ] **Criar configuração Redis em `src/config/redis.py`**
|
||||
|
||||
- Implementar função de conexão com pool
|
||||
- Configurar opções de segurança (SSL, autenticação)
|
||||
- Configurar TTL padrão para diferentes tipos de dados
|
||||
|
||||
- [ ] **Criar serviço de cache Redis em `src/services/redis_cache_service.py`**
|
||||
|
||||
- Implementar métodos do exemplo com suporte a Redis:
|
||||
```python
|
||||
class RedisCache:
|
||||
async def get_task(self, task_id: str) -> dict
|
||||
async def save_task(self, task_id: str, task_data: dict, ttl: int = 3600) -> None
|
||||
async def update_task_status(self, task_id: str, status: dict) -> bool
|
||||
async def append_to_history(self, task_id: str, message: dict) -> bool
|
||||
async def save_push_notification_config(self, task_id: str, config: dict) -> None
|
||||
async def get_push_notification_config(self, task_id: str) -> dict
|
||||
async def save_sse_client(self, task_id: str, client_id: str) -> None
|
||||
async def get_sse_clients(self, task_id: str) -> list
|
||||
async def remove_sse_client(self, task_id: str, client_id: str) -> None
|
||||
```
|
||||
|
||||
- [ ] **Implementar funcionalidades para gerenciamento de conexões**
|
||||
- Reconexão automática
|
||||
- Fallback para cache em memória em caso de falha
|
||||
- Métricas de desempenho
|
||||
|
||||
## 4. Serviços A2A
|
||||
|
||||
- [ ] **Implementar utilitários A2A em `src/utils/a2a_utils.py`**
|
||||
|
||||
- Implementar funções conforme `docs/A2A/samples/python/common/server/utils.py`:
|
||||
```python
|
||||
def are_modalities_compatible(server_output_modes, client_output_modes)
|
||||
def new_incompatible_types_error(request_id)
|
||||
def new_not_implemented_error(request_id)
|
||||
```
|
||||
|
||||
- [ ] **Implementar `A2ATaskManager` em `src/services/a2a_task_manager_service.py`**
|
||||
|
||||
- Seguir a interface do `TaskManager` do exemplo
|
||||
- Implementar todos os métodos abstratos:
|
||||
```python
|
||||
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse
|
||||
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse
|
||||
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse
|
||||
async def on_send_task_subscribe(self, request) -> Union[AsyncIterable, JSONRPCResponse]
|
||||
async def on_set_task_push_notification(self, request) -> SetTaskPushNotificationResponse
|
||||
async def on_get_task_push_notification(self, request) -> GetTaskPushNotificationResponse
|
||||
async def on_resubscribe_to_task(self, request) -> Union[AsyncIterable, JSONRPCResponse]
|
||||
```
|
||||
- Utilizar Redis para persistência de dados de tarefa
|
||||
|
||||
- [ ] **Implementar `A2AServer` em `src/services/a2a_server_service.py`**
|
||||
|
||||
- Processar requisições JSON-RPC conforme `docs/A2A/samples/python/common/server/server.py`:
|
||||
```python
|
||||
async def _process_request(self, request: Request)
|
||||
def _handle_exception(self, e: Exception) -> JSONResponse
|
||||
def _create_response(self, result: Any) -> Union[JSONResponse, EventSourceResponse]
|
||||
```
|
||||
|
||||
- [ ] **Integrar com agent_runner.py existente**
|
||||
|
||||
- Adaptar `run_agent` para uso no contexto de tarefas A2A
|
||||
- Implementar mapeamento entre formatos de mensagem
|
||||
|
||||
- [ ] **Integrar com streaming_service.py existente**
|
||||
- Adaptar para formato de eventos compatível com A2A
|
||||
- Implementar suporte a streaming de múltiplos tipos de eventos
|
||||
|
||||
## 5. Autenticação e Push Notifications
|
||||
|
||||
- [ ] **Implementar `PushNotificationAuth` em `src/services/push_notification_auth_service.py`**
|
||||
|
||||
- Seguir o exemplo em `docs/A2A/samples/python/common/utils/push_notification_auth.py`
|
||||
- Implementar:
|
||||
```python
|
||||
def generate_jwk(self)
|
||||
def handle_jwks_endpoint(self, request: Request)
|
||||
async def send_authenticated_push_notification(self, url: str, data: dict)
|
||||
```
|
||||
|
||||
- [ ] **Implementar verificação de URL de notificação**
|
||||
|
||||
- Seguir método `verify_push_notification_url` do exemplo
|
||||
- Implementar validação de token para verificação
|
||||
|
||||
- [ ] **Implementar armazenamento seguro de chaves**
|
||||
- Armazenar chaves privadas de forma segura
|
||||
- Rotação periódica de chaves
|
||||
- Gerenciamento do ciclo de vida das chaves
|
||||
|
||||
## 6. Rotas A2A
|
||||
|
||||
- [ ] **Implementar rotas em `src/api/a2a_routes.py`**
|
||||
|
||||
- Criar endpoint principal para processamento de requisições JSON-RPC:
|
||||
|
||||
```python
|
||||
@router.post("/{agent_id}")
|
||||
async def process_a2a_request(agent_id: str, request: Request, x_api_key: str = Header(None))
|
||||
```
|
||||
|
||||
- Implementar endpoint do Agent Card reutilizando lógica existente:
|
||||
|
||||
```python
|
||||
@router.get("/{agent_id}/.well-known/agent.json")
|
||||
async def get_agent_card(agent_id: str, db: Session = Depends(get_db))
|
||||
```
|
||||
|
||||
- Implementar endpoint JWKS para autenticação de push notifications:
|
||||
```python
|
||||
@router.get("/{agent_id}/.well-known/jwks.json")
|
||||
async def get_jwks(agent_id: str, db: Session = Depends(get_db))
|
||||
```
|
||||
|
||||
- [ ] **Registrar rotas A2A no aplicativo principal**
|
||||
- Adicionar importação e inclusão em `src/main.py`:
|
||||
```python
|
||||
app.include_router(a2a_routes.router, prefix="/api/v1")
|
||||
```
|
||||
|
||||
## 7. Testes
|
||||
|
||||
- [ ] **Criar testes unitários para schemas A2A**
|
||||
|
||||
- Testar validadores
|
||||
- Testar conversões de formato
|
||||
- Testar compatibilidade de modalidades
|
||||
|
||||
- [ ] **Criar testes unitários para cache Redis**
|
||||
|
||||
- Testar todas as operações CRUD
|
||||
- Testar expiração de dados
|
||||
- Testar comportamento com falhas de conexão
|
||||
|
||||
- [ ] **Criar testes unitários para gerenciador de tarefas**
|
||||
|
||||
- Testar ciclo de vida da tarefa
|
||||
- Testar cancelamento de tarefas
|
||||
- Testar notificações push
|
||||
|
||||
- [ ] **Criar testes de integração para endpoints A2A**
|
||||
- Testar requisições completas
|
||||
- Testar streaming
|
||||
- Testar cenários de erro
|
||||
|
||||
## 8. Segurança
|
||||
|
||||
- [ ] **Implementar validação de API key**
|
||||
|
||||
- Verificar API key para todas as requisições
|
||||
- Implementar rate limiting por agente/cliente
|
||||
|
||||
- [ ] **Configurar segurança no Redis**
|
||||
|
||||
- Ativar autenticação e SSL em produção
|
||||
- Definir políticas de retenção de dados
|
||||
- Implementar backup e recuperação
|
||||
|
||||
- [ ] **Configurar segurança para push notifications**
|
||||
- Implementar assinatura JWT
|
||||
- Validar URLs de callback
|
||||
- Implementar retry com backoff para falhas
|
||||
|
||||
## 9. Monitoramento e Métricas
|
||||
|
||||
- [ ] **Implementar métricas de Redis**
|
||||
|
||||
- Taxa de acertos/erros do cache
|
||||
- Tempo de resposta
|
||||
- Uso de memória
|
||||
|
||||
- [ ] **Implementar métricas de tarefas A2A**
|
||||
|
||||
- Número de tarefas por estado
|
||||
- Tempo médio de processamento
|
||||
- Taxa de erros
|
||||
|
||||
- [ ] **Configurar logging apropriado**
|
||||
- Registrar eventos importantes
|
||||
- Mascarar dados sensíveis
|
||||
- Implementar níveis de log configuráveis
|
||||
|
||||
## 10. Documentação
|
||||
|
||||
- [ ] **Documentar API A2A**
|
||||
|
||||
- Descrever endpoints e formatos
|
||||
- Fornecer exemplos de uso
|
||||
- Documentar erros e soluções
|
||||
|
||||
- [ ] **Documentar integração com Redis**
|
||||
|
||||
- Descrever configuração
|
||||
- Explicar estratégia de cache
|
||||
- Documentar TTLs e políticas de expiração
|
||||
|
||||
- [ ] **Criar exemplos de clients**
|
||||
- Implementar exemplos de uso em Python
|
||||
- Documentar fluxos comuns
|
||||
- Fornecer snippets para linguagens populares
|
187
a2a_client_test.py
Normal file
187
a2a_client_test.py
Normal file
@ -0,0 +1,187 @@
|
||||
import logging
|
||||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
from typing import Any, AsyncIterable, Optional
|
||||
from docs.A2A.samples.python.common.types import (
|
||||
AgentCard,
|
||||
GetTaskRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
JSONRPCRequest,
|
||||
GetTaskResponse,
|
||||
CancelTaskResponse,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
)
|
||||
import json
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
|
||||
# Configurar logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("a2a_client_runner")
|
||||
|
||||
|
||||
class A2ACardResolver:
|
||||
def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.agent_card_path = agent_card_path.lstrip("/")
|
||||
|
||||
def get_agent_card(self) -> AgentCard:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(self.base_url + "/" + self.agent_card_path)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
return AgentCard(**response.json())
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
|
||||
|
||||
class A2AClient:
|
||||
def __init__(
|
||||
self,
|
||||
agent_card: AgentCard = None,
|
||||
url: str = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
if agent_card:
|
||||
self.url = agent_card.url
|
||||
elif url:
|
||||
self.url = url
|
||||
else:
|
||||
raise ValueError("Must provide either agent_card or url")
|
||||
self.api_key = api_key
|
||||
self.headers = {"x-api-key": api_key} if api_key else {}
|
||||
|
||||
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
|
||||
request = SendTaskRequest(params=payload)
|
||||
return SendTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def send_task_streaming(
|
||||
self, payload: dict[str, Any]
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
request = SendTaskStreamingRequest(params=payload)
|
||||
with httpx.Client(timeout=None) as client:
|
||||
with connect_sse(
|
||||
client,
|
||||
"POST",
|
||||
self.url,
|
||||
json=request.model_dump(),
|
||||
headers=self.headers,
|
||||
) as event_source:
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
yield SendTaskStreamingResponse(**json.loads(sse.data))
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
except httpx.RequestError as e:
|
||||
raise A2AClientHTTPError(400, str(e)) from e
|
||||
|
||||
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# Image generation could take time, adding timeout
|
||||
response = await client.post(
|
||||
self.url,
|
||||
json=request.model_dump(),
|
||||
headers=self.headers,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
|
||||
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
|
||||
request = GetTaskRequest(params=payload)
|
||||
return GetTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
|
||||
request = CancelTaskRequest(params=payload)
|
||||
return CancelTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def set_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
request = SetTaskPushNotificationRequest(params=payload)
|
||||
return SetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
||||
async def get_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
request = GetTaskPushNotificationRequest(params=payload)
|
||||
return GetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
||||
|
||||
async def main():
|
||||
# Configurações
|
||||
BASE_URL = "http://localhost:8000/api/v1/a2a/18a2889e-8573-4e70-833c-7d9e00a8fd80"
|
||||
API_KEY = "83c2c19f-dc2e-4abe-9a41-ef7d2eb079d6"
|
||||
|
||||
try:
|
||||
# Obter o card do agente
|
||||
logger.info("Obtendo card do agente...")
|
||||
card_resolver = A2ACardResolver(BASE_URL)
|
||||
try:
|
||||
card = card_resolver.get_agent_card()
|
||||
logger.info(f"Card do agente: {card}")
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao obter card do agente: {e}")
|
||||
return
|
||||
|
||||
# Criar cliente A2A com API key
|
||||
client = A2AClient(card, api_key=API_KEY)
|
||||
|
||||
# Exemplo 1: Enviar tarefa síncrona
|
||||
logger.info("\n=== TESTE DE TAREFA SÍNCRONA ===")
|
||||
task_id = str(uuid.uuid4())
|
||||
session_id = "test-session-1"
|
||||
|
||||
# Preparar payload da tarefa
|
||||
payload = {
|
||||
"id": task_id,
|
||||
"sessionId": session_id,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Quais são os três maiores países do mundo em área territorial?",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"Enviando tarefa com ID: {task_id}")
|
||||
async for streaming_response in client.send_task_streaming(payload):
|
||||
if hasattr(streaming_response.result, "artifact"):
|
||||
# Processar conteúdo parcial
|
||||
print(streaming_response.result.artifact.parts[0].text)
|
||||
elif (
|
||||
hasattr(streaming_response.result, "status")
|
||||
and streaming_response.result.status.state == "completed"
|
||||
):
|
||||
# Tarefa concluída
|
||||
print(
|
||||
"Resposta final:",
|
||||
streaming_response.result.status.message.parts[0].text,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erro durante execução dos testes: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
491
a2a_feature.md
491
a2a_feature.md
@ -1,491 +0,0 @@
|
||||
# Implementação do Servidor A2A (Agent-to-Agent)
|
||||
|
||||
## Visão Geral
|
||||
|
||||
Este documento descreve o plano de implementação para integrar o servidor A2A (Agent-to-Agent) no sistema existente. A implementação seguirá os exemplos fornecidos nos arquivos de referência, adaptando-os à estrutura atual do projeto.
|
||||
|
||||
## Componentes a Serem Implementados
|
||||
|
||||
### 1. Servidor A2A
|
||||
|
||||
Implementação da classe `A2AServer` como serviço para gerenciar requisições JSON-RPC compatíveis com o protocolo A2A.
|
||||
|
||||
### 2. Gerenciador de Tarefas
|
||||
|
||||
Implementação do `TaskManager` como serviço para gerenciar o ciclo de vida das tarefas do agente.
|
||||
|
||||
### 3. Adaptadores para Integração
|
||||
|
||||
Criação de adaptadores para integrar o servidor A2A com os serviços existentes, como o streaming_service e push_notification_service.
|
||||
|
||||
## Rotas e Endpoints A2A
|
||||
|
||||
O protocolo A2A requer a implementação das seguintes rotas JSON-RPC:
|
||||
|
||||
### 1. `POST /a2a/{agent_id}`
|
||||
|
||||
Endpoint principal que processa todas as solicitações JSON-RPC para um agente específico.
|
||||
|
||||
### 2. `GET /a2a/{agent_id}/.well-known/agent.json`
|
||||
|
||||
Retorna o Agent Card contendo as informações do agente.
|
||||
|
||||
### Métodos JSON-RPC implementados:
|
||||
|
||||
1. **tasks/send** - Envia uma nova tarefa para o agente.
|
||||
|
||||
- Parâmetros: `id`, `sessionId`, `message`, `acceptedOutputModes` (opcional), `pushNotification` (opcional)
|
||||
- Retorna: Status da tarefa, artefatos gerados e histórico
|
||||
|
||||
2. **tasks/sendSubscribe** - Envia uma tarefa e assina para receber atualizações via streaming (SSE).
|
||||
|
||||
- Parâmetros: Mesmos do `tasks/send`
|
||||
- Retorna: Stream de eventos SSE com atualizações de status e artefatos
|
||||
|
||||
3. **tasks/get** - Obtém o status atual de uma tarefa.
|
||||
|
||||
- Parâmetros: `id`, `historyLength` (opcional)
|
||||
- Retorna: Status atual da tarefa, artefatos e histórico
|
||||
|
||||
4. **tasks/cancel** - Tenta cancelar uma tarefa em execução.
|
||||
|
||||
- Parâmetros: `id`
|
||||
- Retorna: Status atualizado da tarefa
|
||||
|
||||
5. **tasks/pushNotification/set** - Configura notificações push para uma tarefa.
|
||||
|
||||
- Parâmetros: `id`, `pushNotificationConfig` (URL e autenticação)
|
||||
- Retorna: Configuração de notificação atualizada
|
||||
|
||||
6. **tasks/pushNotification/get** - Obtém a configuração de notificações push de uma tarefa.
|
||||
|
||||
- Parâmetros: `id`
|
||||
- Retorna: Configuração de notificação atual
|
||||
|
||||
7. **tasks/resubscribe** - Reassina para receber eventos de uma tarefa existente.
|
||||
- Parâmetros: `id`
|
||||
- Retorna: Stream de eventos SSE
|
||||
|
||||
## Streaming e Push Notifications
|
||||
|
||||
### Streaming (SSE)
|
||||
|
||||
O streaming será implementado usando Server-Sent Events (SSE) para enviar atualizações em tempo real aos clientes.
|
||||
|
||||
#### Integração com streaming_service.py existente
|
||||
|
||||
O serviço atual `StreamingService` já implementa funcionalidades de streaming SSE. Iremos expandir e adaptar:
|
||||
|
||||
```python
|
||||
# Exemplo de integração com o streaming_service.py existente
|
||||
async def send_task_streaming(request: SendTaskStreamingRequest) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
stream_service = StreamingService()
|
||||
|
||||
async for event in stream_service.send_task_streaming(
|
||||
agent_id=request.params.metadata.get("agent_id"),
|
||||
api_key=api_key,
|
||||
message=request.params.message.parts[0].text,
|
||||
session_id=request.params.sessionId,
|
||||
db=db
|
||||
):
|
||||
# Converter formato de evento SSE para formato A2A
|
||||
yield SendTaskStreamingResponse(
|
||||
id=request.id,
|
||||
result=convert_to_a2a_event_format(event)
|
||||
)
|
||||
```
|
||||
|
||||
### Push Notifications
|
||||
|
||||
O sistema de notificações push permitirá que o agente envie atualizações para URLs de callback configuradas pelos clientes.
|
||||
|
||||
#### Integração com push_notification_service.py existente
|
||||
|
||||
O serviço atual `PushNotificationService` já implementa o envio de notificações. Iremos adaptar:
|
||||
|
||||
```python
|
||||
# Exemplo de integração com o push_notification_service.py existente
|
||||
async def send_push_notification(task_id, state, message=None):
|
||||
notification_config = await get_push_notification_config(task_id)
|
||||
if notification_config:
|
||||
await push_notification_service.send_notification(
|
||||
url=notification_config.url,
|
||||
task_id=task_id,
|
||||
state=state,
|
||||
message=message,
|
||||
headers=notification_config.headers
|
||||
)
|
||||
```
|
||||
|
||||
#### Autenticação de Push Notifications
|
||||
|
||||
Implementaremos autenticação segura para as notificações push baseada em JWT usando o `PushNotificationAuth`:
|
||||
|
||||
```python
|
||||
# Exemplo de como configurar autenticação nas notificações
|
||||
push_auth = PushNotificationSenderAuth()
|
||||
push_auth.generate_jwk()
|
||||
|
||||
# Incluir rota para obter as chaves públicas
|
||||
@router.get("/{agent_id}/.well-known/jwks.json")
|
||||
async def get_jwks(agent_id: str):
|
||||
return push_auth.handle_jwks_endpoint(request)
|
||||
|
||||
# Integrar autenticação ao enviar notificações
|
||||
async def send_authenticated_push_notification(url, data):
|
||||
await push_auth.send_push_notification(url, data)
|
||||
```
|
||||
|
||||
## Estratégia de Armazenamento de Dados
|
||||
|
||||
### Uso de Redis para Dados Temporários
|
||||
|
||||
Utilizaremos Redis para armazenamento e gerenciamento dos dados temporários das tarefas A2A, substituindo o cache em memória do exemplo original:
|
||||
|
||||
```python
|
||||
from src.services.redis_cache_service import RedisCacheService
|
||||
|
||||
class RedisCache:
|
||||
def __init__(self, redis_service: RedisCacheService):
|
||||
self.redis = redis_service
|
||||
|
||||
# Métodos para gerenciamento de tarefas
|
||||
async def get_task(self, task_id: str) -> dict:
|
||||
"""Recupera uma tarefa pelo ID."""
|
||||
return self.redis.get(f"task:{task_id}")
|
||||
|
||||
async def save_task(self, task_id: str, task_data: dict, ttl: int = 3600) -> None:
|
||||
"""Salva uma tarefa com TTL configurável."""
|
||||
self.redis.set(f"task:{task_id}", task_data, ttl=ttl)
|
||||
|
||||
async def update_task_status(self, task_id: str, status: dict) -> bool:
|
||||
"""Atualiza o status de uma tarefa."""
|
||||
task_data = await self.get_task(task_id)
|
||||
if not task_data:
|
||||
return False
|
||||
task_data["status"] = status
|
||||
await self.save_task(task_id, task_data)
|
||||
return True
|
||||
|
||||
# Métodos para histórico de tarefas
|
||||
async def append_to_history(self, task_id: str, message: dict) -> bool:
|
||||
"""Adiciona uma mensagem ao histórico da tarefa."""
|
||||
task_data = await self.get_task(task_id)
|
||||
if not task_data:
|
||||
return False
|
||||
|
||||
if "history" not in task_data:
|
||||
task_data["history"] = []
|
||||
|
||||
task_data["history"].append(message)
|
||||
await self.save_task(task_id, task_data)
|
||||
return True
|
||||
|
||||
# Métodos para notificações push
|
||||
async def save_push_notification_config(self, task_id: str, config: dict) -> None:
|
||||
"""Salva a configuração de notificação push para uma tarefa."""
|
||||
self.redis.set(f"push_notification:{task_id}", config, ttl=3600)
|
||||
|
||||
async def get_push_notification_config(self, task_id: str) -> dict:
|
||||
"""Recupera a configuração de notificação push de uma tarefa."""
|
||||
return self.redis.get(f"push_notification:{task_id}")
|
||||
|
||||
# Métodos para SSE (Server-Sent Events)
|
||||
async def save_sse_client(self, task_id: str, client_id: str) -> None:
|
||||
"""Registra um cliente SSE para uma tarefa."""
|
||||
self.redis.set_hash(f"sse_clients:{task_id}", client_id, "active")
|
||||
|
||||
async def get_sse_clients(self, task_id: str) -> list:
|
||||
"""Recupera todos os clientes SSE registrados para uma tarefa."""
|
||||
return self.redis.get_all_hash(f"sse_clients:{task_id}")
|
||||
|
||||
async def remove_sse_client(self, task_id: str, client_id: str) -> None:
|
||||
"""Remove um cliente SSE do registro."""
|
||||
self.redis.delete_hash(f"sse_clients:{task_id}", client_id)
|
||||
```
|
||||
|
||||
O serviço Redis será configurado com TTL (time-to-live) para garantir a limpeza automática de dados temporários:
|
||||
|
||||
```python
|
||||
# Configuração de TTL para diferentes tipos de dados
|
||||
TASK_TTL = 3600 # 1 hora para tarefas
|
||||
HISTORY_TTL = 86400 # 24 horas para histórico
|
||||
PUSH_NOTIFICATION_TTL = 3600 # 1 hora para configurações de notificação
|
||||
SSE_CLIENT_TTL = 300 # 5 minutos para clientes SSE
|
||||
```
|
||||
|
||||
### Modelos Existentes e Redis
|
||||
|
||||
O sistema continuará utilizando os modelos SQLAlchemy existentes para dados permanentes:
|
||||
|
||||
- **Agent**: Dados do agente e configurações
|
||||
- **Session**: Sessões persistentes
|
||||
- **MCPServer**: Configurações de servidores de ferramentas
|
||||
|
||||
Para dados temporários (tarefas A2A, histórico, streaming), utilizaremos Redis que oferece:
|
||||
|
||||
1. **Performance**: Operações em memória com persistência opcional
|
||||
2. **TTL**: Expiração automática de dados temporários
|
||||
3. **Estruturas de dados**: Suporte a strings, hashes, listas para diferentes necessidades
|
||||
4. **Pub/Sub**: Mecanismo para notificações em tempo real
|
||||
5. **Escalabilidade**: Melhor suporte a múltiplas instâncias do que cache em memória
|
||||
|
||||
### Implementação do TaskManager com Redis
|
||||
|
||||
O `A2ATaskManager` implementará a mesma interface que o `TaskManager` do exemplo, mas utilizando Redis:
|
||||
|
||||
```python
|
||||
class A2ATaskManager:
|
||||
"""
|
||||
Gerenciador de tarefas A2A usando Redis para armazenamento.
|
||||
Implementa a interface do protocolo A2A para gerenciamento do ciclo de vida das tarefas.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_cache: RedisCacheService,
|
||||
session_service=None,
|
||||
artifacts_service=None,
|
||||
memory_service=None,
|
||||
push_notification_service=None
|
||||
):
|
||||
self.redis_cache = redis_cache
|
||||
self.session_service = session_service
|
||||
self.artifacts_service = artifacts_service
|
||||
self.memory_service = memory_service
|
||||
self.push_notification_service = push_notification_service
|
||||
self.lock = asyncio.Lock()
|
||||
self.subscriber_lock = asyncio.Lock()
|
||||
self.task_sse_subscribers = {}
|
||||
|
||||
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
|
||||
"""
|
||||
Obtém o status atual de uma tarefa.
|
||||
|
||||
Args:
|
||||
request: Requisição JSON-RPC para obter dados da tarefa
|
||||
|
||||
Returns:
|
||||
Resposta com dados da tarefa ou erro
|
||||
"""
|
||||
logger.info(f"Getting task {request.params.id}")
|
||||
task_query_params = request.params
|
||||
|
||||
task_data = self.redis_cache.get(f"task:{task_query_params.id}")
|
||||
if not task_data:
|
||||
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Processar histórico conforme solicitado
|
||||
if task_query_params.historyLength and task_data.get("history"):
|
||||
task_data["history"] = task_data["history"][-task_query_params.historyLength:]
|
||||
|
||||
return GetTaskResponse(id=request.id, result=task_data)
|
||||
```
|
||||
|
||||
## Implementação do A2A Server Service
|
||||
|
||||
O serviço `A2AServer` processará as requisições JSON-RPC conforme o protocolo A2A:
|
||||
|
||||
```python
|
||||
class A2AServer:
|
||||
"""
|
||||
Servidor A2A que implementa o protocolo JSON-RPC para processamento de tarefas de agentes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str = "/",
|
||||
agent_card = None,
|
||||
task_manager = None,
|
||||
streaming_service = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.agent_card = agent_card
|
||||
self.task_manager = task_manager
|
||||
self.streaming_service = streaming_service
|
||||
|
||||
async def _process_request(self, request: Request):
|
||||
"""
|
||||
Processa uma requisição JSON-RPC do protocolo A2A.
|
||||
|
||||
Args:
|
||||
request: Requisição HTTP
|
||||
|
||||
Returns:
|
||||
Resposta JSON-RPC ou stream de eventos
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
json_rpc_request = A2ARequest.validate_python(body)
|
||||
|
||||
# Delegar para o handler apropriado com base no tipo de requisição
|
||||
if isinstance(json_rpc_request, GetTaskRequest):
|
||||
result = await self.task_manager.on_get_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskRequest):
|
||||
result = await self.task_manager.on_send_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
|
||||
result = await self.task_manager.on_send_task_subscribe(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, CancelTaskRequest):
|
||||
result = await self.task_manager.on_cancel_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
|
||||
result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
|
||||
result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
|
||||
result = await self.task_manager.on_resubscribe_to_task(json_rpc_request)
|
||||
else:
|
||||
logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
|
||||
raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")
|
||||
|
||||
return self._create_response(result)
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_exception(e)
|
||||
```
|
||||
|
||||
## Implementação da Autenticação para Push Notifications
|
||||
|
||||
Implementaremos autenticação JWT para notificações push, seguindo o exemplo:
|
||||
|
||||
```python
|
||||
class PushNotificationAuth:
|
||||
def __init__(self):
|
||||
self.public_keys = []
|
||||
self.private_key_jwk = None
|
||||
|
||||
def generate_jwk(self):
|
||||
key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig")
|
||||
self.public_keys.append(key.export_public(as_dict=True))
|
||||
self.private_key_jwk = PyJWK.from_json(key.export_private())
|
||||
|
||||
def handle_jwks_endpoint(self, request: Request):
|
||||
"""Retorna as chaves públicas para clientes."""
|
||||
return JSONResponse({"keys": self.public_keys})
|
||||
|
||||
async def send_authenticated_push_notification(self, url: str, data: dict):
|
||||
"""Envia notificação push assinada com JWT."""
|
||||
jwt_token = self._generate_jwt(data)
|
||||
headers = {'Authorization': f"Bearer {jwt_token}"}
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
response = await client.post(url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info(f"Push-notification sent to URL: {url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending push-notification to URL {url}: {e}")
|
||||
```
|
||||
|
||||
## Revisão das Rotas A2A
|
||||
|
||||
As rotas A2A serão implementadas em `src/api/a2a_routes.py`, utilizando a lógica de AgentCard do código existente:
|
||||
|
||||
```python
|
||||
@router.post("/{agent_id}")
|
||||
async def process_a2a_request(
|
||||
agent_id: str,
|
||||
request: Request,
|
||||
x_api_key: str = Header(None, alias="x-api-key"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Endpoint que processa requisições JSON-RPC do protocolo A2A.
|
||||
"""
|
||||
# Validar agente e API key
|
||||
agent = get_agent(db, agent_id)
|
||||
if not agent:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Agente não encontrado"}
|
||||
)
|
||||
|
||||
if agent.config.get("api_key") != x_api_key:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Chave API inválida"}
|
||||
)
|
||||
|
||||
# Criar Agent Card para o agente (reutilizando lógica existente)
|
||||
agent_card = create_agent_card_from_agent(agent, db)
|
||||
|
||||
# Configurar o servidor A2A para este agente
|
||||
a2a_server.agent_card = agent_card
|
||||
|
||||
# Processar a requisição A2A
|
||||
return await a2a_server._process_request(request)
|
||||
```
|
||||
|
||||
## Arquivos a Serem Criados/Atualizados
|
||||
|
||||
### Novos Arquivos
|
||||
|
||||
1. `src/schemas/a2a.py` - Modelos Pydantic para o protocolo A2A
|
||||
2. `src/services/redis_cache_service.py` - Serviço de cache Redis
|
||||
3. `src/config/redis.py` - Configuração do cliente Redis
|
||||
4. `src/utils/a2a_utils.py` - Utilitários para o protocolo A2A
|
||||
5. `src/services/a2a_task_manager_service.py` - Gerenciador de tarefas A2A
|
||||
6. `src/services/a2a_server_service.py` - Servidor A2A
|
||||
7. `src/services/push_notification_auth_service.py` - Autenticação para push notifications
|
||||
8. `src/api/a2a_routes.py` - Rotas para o protocolo A2A
|
||||
|
||||
### Arquivos a Serem Atualizados
|
||||
|
||||
1. `src/main.py` - Registrar novas rotas A2A
|
||||
2. `pyproject.toml` - Adicionar dependências (Redis, jwcrypto, etc.)
|
||||
|
||||
## Plano de Implementação
|
||||
|
||||
### Fase 1: Criação dos Esquemas
|
||||
|
||||
1. Criar arquivo `src/schemas/a2a.py` com os modelos Pydantic baseados no arquivo `common/types.py`
|
||||
2. Adaptar os tipos para a estrutura do projeto e adicionar suporte para streaming e push notifications
|
||||
|
||||
### Fase 2: Implementação do Serviço de Cache Redis
|
||||
|
||||
1. Criar arquivo `src/config/redis.py` para configuração do cliente Redis
|
||||
2. Criar arquivo `src/services/redis_cache_service.py` para gerenciamento de cache
|
||||
|
||||
### Fase 3: Implementação de Utilitários
|
||||
|
||||
1. Criar arquivo `src/utils/a2a_utils.py` com funções utilitárias baseadas em `common/server/utils.py`
|
||||
2. Adaptar o `PushNotificationAuth` para uso no contexto A2A
|
||||
|
||||
### Fase 4: Implementação do Gerenciador de Tarefas
|
||||
|
||||
1. Criar arquivo `src/services/a2a_task_manager_service.py` com a implementação do `A2ATaskManager`
|
||||
2. Integrar com serviços existentes:
|
||||
- agent_runner.py para execução de agentes
|
||||
- streaming_service.py para streaming SSE
|
||||
- push_notification_service.py para push notifications
|
||||
- redis_cache_service.py para cache de tarefas
|
||||
|
||||
### Fase 5: Implementação do Servidor A2A
|
||||
|
||||
1. Criar arquivo `src/services/a2a_server_service.py` com a implementação do `A2AServer`
|
||||
2. Implementar processamento de requisições JSON-RPC para todas as operações A2A
|
||||
|
||||
### Fase 6: Integração
|
||||
|
||||
1. Criar arquivo `src/api/a2a_routes.py` com rotas para o protocolo A2A
|
||||
2. Registrar as rotas no aplicativo FastAPI principal
|
||||
3. Assegurar que todas as operações A2A funcionem corretamente, incluindo streaming e push notifications
|
||||
|
||||
## Adaptações Necessárias
|
||||
|
||||
1. **Esquemas**: Adaptar os modelos do protocolo A2A para usar os esquemas Pydantic existentes quando possível
|
||||
2. **Autenticação**: Integrar com o sistema de autenticação existente usando API keys
|
||||
3. **Streaming**: Adaptar o `StreamingService` existente para o formato de eventos A2A
|
||||
4. **Push Notifications**: Integrar o `PushNotificationService` existente e adicionar suporte a autenticação JWT
|
||||
5. **Cache**: Utilizar Redis para armazenamento temporário de tarefas e eventos
|
||||
6. **Execução de Agentes**: Reutilizar o serviço existente `agent_runner.py` para execução
|
||||
|
||||
## Próximos Passos
|
||||
|
||||
1. Configurar dependências do Redis no projeto
|
||||
2. Implementar os esquemas em `src/schemas/a2a.py`
|
||||
3. Implementar o serviço de cache Redis
|
||||
4. Implementar as funções utilitárias em `src/utils/a2a_utils.py`
|
||||
5. Implementar o gerenciador de tarefas em `src/services/a2a_task_manager_service.py`
|
||||
6. Implementar o servidor A2A em `src/services/a2a_server_service.py`
|
||||
7. Implementar as rotas em `src/api/a2a_routes.py`
|
||||
8. Registrar as rotas no aplicativo principal `src/main.py`
|
||||
9. Testar a implementação com casos de uso completos, incluindo streaming e push notifications
|
@ -10,14 +10,21 @@ services:
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@postgres:5432/evo_ai
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=
|
||||
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
||||
- SENDGRID_API_KEY=${SENDGRID_API_KEY}
|
||||
- EMAIL_FROM=${EMAIL_FROM}
|
||||
- APP_URL=${APP_URL}
|
||||
POSTGRES_CONNECTION_STRING: postgresql://postgres:postgres@postgres:5432/evo_ai
|
||||
REDIS_HOST: redis
|
||||
REDIS_PORT: 6379
|
||||
REDIS_PASSWORD: ""
|
||||
REDIS_SSL: "false"
|
||||
REDIS_KEY_PREFIX: "a2a:"
|
||||
REDIS_TTL: 3600
|
||||
A2A_TASK_TTL: 3600
|
||||
A2A_HISTORY_TTL: 86400
|
||||
A2A_PUSH_NOTIFICATION_TTL: 3600
|
||||
A2A_SSE_CLIENT_TTL: 300
|
||||
JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
||||
SENDGRID_API_KEY: ${SENDGRID_API_KEY}
|
||||
EMAIL_FROM: ${EMAIL_FROM}
|
||||
APP_URL: ${APP_URL}
|
||||
volumes:
|
||||
- ./logs:/app/logs
|
||||
restart: unless-stopped
|
||||
@ -26,9 +33,9 @@ services:
|
||||
image: postgres:14-alpine
|
||||
container_name: evo-ai-postgres
|
||||
environment:
|
||||
- POSTGRES_USER=postgres
|
||||
- POSTGRES_PASSWORD=postgres
|
||||
- POSTGRES_DB=evo_ai
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: evo_ai
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
@ -42,6 +49,12 @@ services:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 30s
|
||||
retries: 50
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
|
@ -43,6 +43,10 @@ dependencies = [
|
||||
"pydantic[email]==2.11.3",
|
||||
"httpx==0.28.1",
|
||||
"httpx-sse==0.4.0",
|
||||
"redis==5.3.0",
|
||||
"sse-starlette==2.3.3",
|
||||
"jwcrypto==1.5.6",
|
||||
"pyjwt[crypto]==2.9.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
395
src/api/a2a_routes.py
Normal file
395
src/api/a2a_routes.py
Normal file
@ -0,0 +1,395 @@
|
||||
"""
|
||||
Routes for the A2A (Agent-to-Agent) protocol.
|
||||
|
||||
This module implements the standard A2A routes according to the specification.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from src.config.database import get_db
|
||||
from src.services import agent_service
|
||||
from src.services import (
|
||||
RedisCacheService,
|
||||
AgentRunnerAdapter,
|
||||
StreamingServiceAdapter,
|
||||
create_agent_card_from_agent,
|
||||
)
|
||||
from src.services.a2a_task_manager_service import A2ATaskManager
|
||||
from src.services.a2a_server_service import A2AServer
|
||||
from src.services.agent_runner import run_agent
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
)
|
||||
from src.services.push_notification_service import push_notification_service
|
||||
from src.services.push_notification_auth_service import push_notification_auth
|
||||
from src.services.streaming_service import StreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router with prefix /a2a according to the standard protocol
|
||||
router = APIRouter(
|
||||
prefix="/a2a",
|
||||
tags=["a2a"],
|
||||
responses={
|
||||
404: {"description": "Not found"},
|
||||
400: {"description": "Bad request"},
|
||||
401: {"description": "Unauthorized"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
)
|
||||
|
||||
# Singleton instances for shared resources
|
||||
streaming_service = StreamingService()
|
||||
redis_cache_service = RedisCacheService()
|
||||
streaming_adapter = StreamingServiceAdapter(streaming_service)
|
||||
|
||||
# Cache dictionary para manter instâncias de A2ATaskManager por agente
|
||||
# Isso evita criar novas instâncias a cada request
|
||||
_task_manager_cache = {}
|
||||
_agent_runner_cache = {}
|
||||
|
||||
|
||||
def get_agent_runner_adapter(db=None, reuse=True, agent_id=None):
|
||||
"""
|
||||
Get or create an agent runner adapter.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
reuse: Whether to reuse an existing instance
|
||||
agent_id: Agent ID to use as cache key
|
||||
|
||||
Returns:
|
||||
Agent runner adapter instance
|
||||
"""
|
||||
cache_key = str(agent_id) if agent_id else "default"
|
||||
logger.info(
|
||||
f"[DEBUG] get_agent_runner_adapter chamado para agent_id={agent_id}, reuse={reuse}, cache_key={cache_key}"
|
||||
)
|
||||
|
||||
if reuse and cache_key in _agent_runner_cache:
|
||||
adapter = _agent_runner_cache[cache_key]
|
||||
logger.info(
|
||||
f"[DEBUG] Reutilizando agent_runner_adapter existente para {cache_key}"
|
||||
)
|
||||
# Atualizar a sessão DB se fornecida
|
||||
if db is not None:
|
||||
adapter.db = db
|
||||
return adapter
|
||||
|
||||
logger.info(
|
||||
f"[IMPORTANTE] Criando NOVA instância de AgentRunnerAdapter para {cache_key}"
|
||||
)
|
||||
adapter = AgentRunnerAdapter(
|
||||
agent_runner_func=run_agent,
|
||||
session_service=session_service,
|
||||
artifacts_service=artifacts_service,
|
||||
memory_service=memory_service,
|
||||
db=db,
|
||||
)
|
||||
|
||||
if reuse:
|
||||
logger.info(f"[DEBUG] Armazenando nova instância no cache para {cache_key}")
|
||||
_agent_runner_cache[cache_key] = adapter
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"):
|
||||
cache_key = str(agent_id)
|
||||
|
||||
# Para operações de consulta, NUNCA crie um agent_runner
|
||||
if operation_type == "query":
|
||||
if cache_key in _task_manager_cache:
|
||||
# Reutilize existente
|
||||
task_manager = _task_manager_cache[cache_key]
|
||||
task_manager.db = db
|
||||
return task_manager
|
||||
|
||||
# Se não existe, crie um task_manager SEM agent_runner para consultas
|
||||
return A2ATaskManager(
|
||||
redis_cache=redis_cache_service,
|
||||
agent_runner=None, # Sem agent_runner para consultas!
|
||||
streaming_service=streaming_adapter,
|
||||
push_notification_service=push_notification_service,
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Para operações de execução, use o fluxo normal
|
||||
if reuse and cache_key in _task_manager_cache:
|
||||
# Atualize o db
|
||||
task_manager = _task_manager_cache[cache_key]
|
||||
task_manager.db = db
|
||||
return task_manager
|
||||
|
||||
# Create new
|
||||
agent_runner_adapter = get_agent_runner_adapter(
|
||||
db=db, reuse=reuse, agent_id=agent_id
|
||||
)
|
||||
task_manager = A2ATaskManager(
|
||||
redis_cache=redis_cache_service,
|
||||
agent_runner=agent_runner_adapter,
|
||||
streaming_service=streaming_adapter,
|
||||
push_notification_service=push_notification_service,
|
||||
db=db,
|
||||
)
|
||||
_task_manager_cache[cache_key] = task_manager
|
||||
return task_manager
|
||||
|
||||
|
||||
@router.post("/{agent_id}")
|
||||
async def process_a2a_request(
|
||||
agent_id: uuid.UUID,
|
||||
request: Request,
|
||||
x_api_key: str = Header(None, alias="x-api-key"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Main endpoint for processing JSON-RPC requests of the A2A protocol.
|
||||
|
||||
This endpoint processes all JSON-RPC methods of the A2A protocol, including:
|
||||
- tasks/send: Sending tasks
|
||||
- tasks/sendSubscribe: Sending tasks with streaming
|
||||
- tasks/get: Querying task status
|
||||
- tasks/cancel: Cancelling tasks
|
||||
- tasks/pushNotification/set: Setting push notifications
|
||||
- tasks/pushNotification/get: Querying push notification configurations
|
||||
- tasks/resubscribe: Resubscribing to receive task updates
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
request: HTTP request with JSON-RPC payload
|
||||
x_api_key: API key for authentication
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
JSON-RPC response or streaming (SSE) depending on the method
|
||||
"""
|
||||
try:
|
||||
# Detailed request log
|
||||
logger.info(f"Request received for A2A agent {agent_id}")
|
||||
logger.info(f"Headers: {dict(request.headers)}")
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
method = body.get("method", "unknown")
|
||||
logger.info(f"[IMPORTANTE] Método solicitado: {method}")
|
||||
logger.info(f"Request body: {body}")
|
||||
|
||||
# Determinar se é uma solicitação de consulta (get_task) ou execução (send_task)
|
||||
is_query_request = method in [
|
||||
"tasks/get",
|
||||
"tasks/cancel",
|
||||
"tasks/pushNotification/get",
|
||||
"tasks/resubscribe",
|
||||
]
|
||||
# Para consultas, reutilizamos os componentes; para execuções,
|
||||
# criamos novos para garantir estado limpo
|
||||
reuse_components = is_query_request
|
||||
logger.info(
|
||||
f"[IMPORTANTE] Is query request: {is_query_request}, Reuse components: {reuse_components}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading request body: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {
|
||||
"code": -32700,
|
||||
"message": f"Parse error: {str(e)}",
|
||||
"data": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify if the agent exists
|
||||
agent = agent_service.get_agent(db, agent_id)
|
||||
if agent is None:
|
||||
logger.warning(f"Agent not found: {agent_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": 404, "message": "Agent not found", "data": None},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify API key
|
||||
agent_config = agent.config
|
||||
logger.info(f"Received API Key: {x_api_key}")
|
||||
logger.info(f"Expected API Key: {agent_config.get('api_key')}")
|
||||
|
||||
if x_api_key and agent_config.get("api_key") != x_api_key:
|
||||
logger.warning(f"Invalid API Key for agent {agent_id}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": 401, "message": "Invalid API key", "data": None},
|
||||
},
|
||||
)
|
||||
|
||||
# Obter o task manager para este agente (reutilizando se possível)
|
||||
a2a_task_manager = get_task_manager(
|
||||
agent_id,
|
||||
db=db,
|
||||
reuse=reuse_components,
|
||||
operation_type="query" if is_query_request else "execution",
|
||||
)
|
||||
a2a_server = A2AServer(task_manager=a2a_task_manager)
|
||||
|
||||
# Configure agent_card for the A2A server
|
||||
logger.info("Configuring agent_card for A2A server")
|
||||
agent_card = create_agent_card_from_agent(agent, db)
|
||||
a2a_server.agent_card = agent_card
|
||||
|
||||
# Verify JSON-RPC format
|
||||
if not body.get("jsonrpc") or body.get("jsonrpc") != "2.0":
|
||||
logger.error(f"Invalid JSON-RPC format: {body.get('jsonrpc')}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request: jsonrpc must be '2.0'",
|
||||
"data": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the method
|
||||
if not body.get("method"):
|
||||
logger.error("Method not specified in request")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request: method is required",
|
||||
"data": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Processing method: {body.get('method')}")
|
||||
|
||||
# Process the request with the A2A server
|
||||
logger.info("Sending request to A2A server")
|
||||
|
||||
# Pass the agent_id and db directly to the process_request method
|
||||
return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": "Internal server error",
|
||||
"data": {"detail": str(e)},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/.well-known/agent.json")
|
||||
async def get_agent_card(
|
||||
agent_id: uuid.UUID,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Endpoint to get the Agent Card in the .well-known format of the A2A protocol.
|
||||
|
||||
This endpoint returns the agent information in the standard A2A format,
|
||||
including capabilities, authentication information, and skills.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
request: HTTP request
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Agent Card in JSON format
|
||||
"""
|
||||
try:
|
||||
agent = agent_service.get_agent(db, agent_id)
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
|
||||
)
|
||||
|
||||
agent_card = create_agent_card_from_agent(agent, db)
|
||||
|
||||
# Obter o task manager para este agente (reutilizando se possível)
|
||||
a2a_task_manager = get_task_manager(agent_id, db=db, reuse=True)
|
||||
a2a_server = A2AServer(task_manager=a2a_task_manager)
|
||||
|
||||
# Configure the A2A server with the agent card
|
||||
a2a_server.agent_card = agent_card
|
||||
|
||||
# Use the A2A server to deliver the agent card, ensuring protocol compatibility
|
||||
return await a2a_server.get_agent_card(request, db=db)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent card: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error generating agent card",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/.well-known/jwks.json")
|
||||
async def get_jwks(
|
||||
agent_id: uuid.UUID,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Endpoint to get the public JWKS keys for verifying the authenticity
|
||||
of push notifications.
|
||||
|
||||
Clients can use these keys to verify the authenticity of received notifications.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
request: HTTP request
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
JSON with the public keys in JWKS format
|
||||
"""
|
||||
try:
|
||||
# Verify if the agent exists
|
||||
agent = agent_service.get_agent(db, agent_id)
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
|
||||
)
|
||||
|
||||
# Return the public keys
|
||||
return push_notification_auth.handle_jwks_endpoint(request)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error obtaining JWKS: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error obtaining JWKS",
|
||||
)
|
@ -1,6 +1,3 @@
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from src.config.database import get_db
|
||||
@ -18,19 +15,7 @@ from src.services import (
|
||||
agent_service,
|
||||
mcp_server_service,
|
||||
)
|
||||
from src.services.agent_runner import run_agent
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
)
|
||||
from src.services.push_notification_service import push_notification_service
|
||||
import logging
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ..services.streaming_service import StreamingService
|
||||
from ..schemas.streaming import JSONRPCRequest
|
||||
|
||||
from src.services.session_service import get_session_events
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -79,8 +64,6 @@ router = APIRouter(
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
streaming_service = StreamingService()
|
||||
|
||||
|
||||
@router.post("/", response_model=Agent, status_code=status.HTTP_201_CREATED)
|
||||
async def create_agent(
|
||||
@ -172,348 +155,3 @@ async def delete_agent(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/.well-known/agent.json")
|
||||
async def get_agent_json(
|
||||
agent_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
agent = agent_service.get_agent(db, agent_id)
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
|
||||
)
|
||||
|
||||
mcp_servers = agent.config.get("mcp_servers", [])
|
||||
formatted_tools = await format_agent_tools(mcp_servers, db)
|
||||
|
||||
AGENT_CARD = {
|
||||
"name": agent.name,
|
||||
"description": agent.description,
|
||||
"url": f"{os.getenv('API_URL', '')}/api/v1/agents/{agent.id}",
|
||||
"provider": {
|
||||
"organization": os.getenv("ORGANIZATION_NAME", ""),
|
||||
"url": os.getenv("ORGANIZATION_URL", ""),
|
||||
},
|
||||
"version": os.getenv("API_VERSION", ""),
|
||||
"capabilities": {
|
||||
"streaming": True,
|
||||
"pushNotifications": True,
|
||||
"stateTransitionHistory": True,
|
||||
},
|
||||
"authentication": {
|
||||
"schemes": ["apiKey"],
|
||||
"credentials": {"in": "header", "name": "x-api-key"},
|
||||
},
|
||||
"defaultInputModes": ["text", "application/json"],
|
||||
"defaultOutputModes": ["text", "application/json"],
|
||||
"skills": formatted_tools,
|
||||
}
|
||||
return AGENT_CARD
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent card: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error generating agent card",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/tasks/send")
|
||||
async def handle_task(
|
||||
agent_id: uuid.UUID,
|
||||
request: JSONRPCRequest,
|
||||
x_api_key: str = Header(..., alias="x-api-key"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Endpoint to clients A2A send a new task (with an initial user message)."""
|
||||
try:
|
||||
# Verify agent
|
||||
agent = agent_service.get_agent(db, agent_id)
|
||||
if agent is None:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": 404, "message": "Agent not found", "data": None},
|
||||
}
|
||||
|
||||
# Verify API key
|
||||
agent_config = agent.config
|
||||
if agent_config.get("api_key") != x_api_key:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": 401, "message": "Invalid API key", "data": None},
|
||||
}
|
||||
|
||||
# Extract task request from JSON-RPC params
|
||||
task_request = request.params
|
||||
|
||||
# Validate required fields
|
||||
task_id = task_request.get("id")
|
||||
if not task_id:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Invalid parameters",
|
||||
"data": {"detail": "Task ID is required"},
|
||||
},
|
||||
}
|
||||
|
||||
# Extract user message
|
||||
try:
|
||||
user_message = task_request["message"]["parts"][0]["text"]
|
||||
except (KeyError, IndexError):
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Invalid parameters",
|
||||
"data": {"detail": "Invalid message format"},
|
||||
},
|
||||
}
|
||||
|
||||
# Configure session and metadata
|
||||
session_id = f"{task_id}_{agent_id}"
|
||||
metadata = task_request.get("metadata", {})
|
||||
history_length = metadata.get("historyLength", 50)
|
||||
|
||||
# Initialize response
|
||||
response_task = {
|
||||
"id": task_id,
|
||||
"sessionId": session_id,
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": None,
|
||||
"error": None,
|
||||
},
|
||||
"artifacts": [],
|
||||
"history": [],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
# Handle push notification configuration
|
||||
push_notification = task_request.get("pushNotification")
|
||||
if push_notification:
|
||||
url = push_notification.get("url")
|
||||
headers = push_notification.get("headers", {})
|
||||
|
||||
if not url:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Invalid parameters",
|
||||
"data": {"detail": "Push notification URL is required"},
|
||||
},
|
||||
}
|
||||
|
||||
# Store push notification config in metadata
|
||||
response_task["metadata"]["pushNotification"] = {
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
# Send initial notification
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url, task_id=task_id, state="submitted", headers=headers
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Update status to running
|
||||
response_task["status"].update(
|
||||
{"state": "running", "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
|
||||
# Send running notification if configured
|
||||
if push_notification:
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url, task_id=task_id, state="running", headers=headers
|
||||
)
|
||||
)
|
||||
|
||||
# Execute agent
|
||||
final_response_text = await run_agent(
|
||||
str(agent_id),
|
||||
task_id,
|
||||
user_message,
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
db,
|
||||
session_id,
|
||||
)
|
||||
|
||||
# Update status to completed
|
||||
response_task["status"].update(
|
||||
{
|
||||
"state": "completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response_text}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add artifacts
|
||||
if final_response_text:
|
||||
response_task["artifacts"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": final_response_text,
|
||||
"metadata": {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"content_type": "text/plain",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Send completed notification if configured
|
||||
if push_notification:
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url,
|
||||
task_id=task_id,
|
||||
state="completed",
|
||||
message={
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response_text}],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Update status to failed
|
||||
response_task["status"].update(
|
||||
{
|
||||
"state": "failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)},
|
||||
}
|
||||
)
|
||||
|
||||
# Send failed notification if configured
|
||||
if push_notification:
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url,
|
||||
task_id=task_id,
|
||||
state="failed",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": str(e)}],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
|
||||
# Process history
|
||||
try:
|
||||
history_messages = get_session_events(session_service, session_id)
|
||||
history_messages = history_messages[-history_length:]
|
||||
|
||||
formatted_history = []
|
||||
for event in history_messages:
|
||||
if event.content and event.content.parts:
|
||||
role = (
|
||||
"agent" if event.content.role == "model" else event.content.role
|
||||
)
|
||||
formatted_history.append(
|
||||
{
|
||||
"role": role,
|
||||
"parts": [
|
||||
{"type": "text", "text": part.text}
|
||||
for part in event.content.parts
|
||||
if part.text
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
response_task["history"] = formatted_history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing history: {str(e)}")
|
||||
|
||||
# Return JSON-RPC response
|
||||
return {"jsonrpc": "2.0", "id": request.id, "result": response_task}
|
||||
|
||||
except HTTPException as e:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": e.status_code, "message": e.detail, "data": None},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in handle_task: {str(e)}")
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": "Internal server error",
|
||||
"data": {"detail": str(e)},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/tasks/sendSubscribe")
|
||||
async def subscribe_task_streaming(
|
||||
agent_id: str,
|
||||
request: JSONRPCRequest,
|
||||
x_api_key: str = Header(None),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Endpoint para streaming de eventos SSE de uma tarefa.
|
||||
|
||||
Args:
|
||||
agent_id: ID do agente
|
||||
request: Requisição JSON-RPC
|
||||
x_api_key: Chave de API no header
|
||||
db: Sessão do banco de dados
|
||||
|
||||
Returns:
|
||||
StreamingResponse com eventos SSE
|
||||
"""
|
||||
if not x_api_key:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": 401, "message": "API key é obrigatória", "data": None},
|
||||
}
|
||||
|
||||
# Extrai mensagem do payload
|
||||
message = request.params.get("message", {}).get("parts", [{}])[0].get("text", "")
|
||||
session_id = request.params.get("sessionId")
|
||||
|
||||
# Configura streaming
|
||||
async def event_generator():
|
||||
async for event in streaming_service.send_task_streaming(
|
||||
agent_id=agent_id,
|
||||
api_key=x_api_key,
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
85
src/config/redis.py
Normal file
85
src/config/redis.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Redis configuration module.
|
||||
|
||||
This module defines the Redis connection settings and provides
|
||||
function to create a Redis connection pool for the application.
|
||||
"""
|
||||
|
||||
import os
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
import logging
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_redis_config():
|
||||
"""
|
||||
Get Redis configuration from environment variables.
|
||||
|
||||
Returns:
|
||||
dict: Redis configuration parameters
|
||||
"""
|
||||
return {
|
||||
"host": os.getenv("REDIS_HOST", "localhost"),
|
||||
"port": int(os.getenv("REDIS_PORT", 6379)),
|
||||
"db": int(os.getenv("REDIS_DB", 0)),
|
||||
"password": os.getenv("REDIS_PASSWORD", None),
|
||||
"ssl": os.getenv("REDIS_SSL", "false").lower() == "true",
|
||||
"key_prefix": os.getenv("REDIS_KEY_PREFIX", "a2a:"),
|
||||
"default_ttl": int(os.getenv("REDIS_TTL", 3600)),
|
||||
}
|
||||
|
||||
|
||||
def get_a2a_config():
|
||||
"""
|
||||
Get A2A-specific cache TTL values from environment variables.
|
||||
|
||||
Returns:
|
||||
dict: A2A TTL configuration parameters
|
||||
"""
|
||||
return {
|
||||
"task_ttl": int(os.getenv("A2A_TASK_TTL", 3600)),
|
||||
"history_ttl": int(os.getenv("A2A_HISTORY_TTL", 86400)),
|
||||
"push_notification_ttl": int(os.getenv("A2A_PUSH_NOTIFICATION_TTL", 3600)),
|
||||
"sse_client_ttl": int(os.getenv("A2A_SSE_CLIENT_TTL", 300)),
|
||||
}
|
||||
|
||||
|
||||
def create_redis_pool(config=None):
|
||||
"""
|
||||
Create and return a Redis connection pool.
|
||||
|
||||
Args:
|
||||
config (dict, optional): Redis configuration. If None,
|
||||
configuration is loaded from environment
|
||||
|
||||
Returns:
|
||||
redis.ConnectionPool: Redis connection pool
|
||||
"""
|
||||
if config is None:
|
||||
config = get_redis_config()
|
||||
|
||||
try:
|
||||
connection_pool = redis.ConnectionPool(
|
||||
host=config["host"],
|
||||
port=config["port"],
|
||||
db=config["db"],
|
||||
password=config["password"] if config["password"] else None,
|
||||
ssl=config["ssl"],
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test the connection
|
||||
redis_client = redis.Redis(connection_pool=connection_pool)
|
||||
redis_client.ping()
|
||||
logger.info(
|
||||
f"Redis connection successful: {config['host']}:{config['port']}, "
|
||||
f"db={config['db']}, ssl={config['ssl']}"
|
||||
)
|
||||
return connection_pool
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Redis connection error: {e}")
|
||||
raise
|
@ -37,6 +37,9 @@ class Settings(BaseSettings):
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", 6379))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", 0))
|
||||
REDIS_PASSWORD: Optional[str] = os.getenv("REDIS_PASSWORD")
|
||||
REDIS_SSL: bool = os.getenv("REDIS_SSL", "false").lower() == "true"
|
||||
REDIS_KEY_PREFIX: str = os.getenv("REDIS_KEY_PREFIX", "evoai:")
|
||||
REDIS_TTL: int = int(os.getenv("REDIS_TTL", 3600))
|
||||
|
||||
# Tool cache TTL in seconds (1 hour)
|
||||
TOOLS_CACHE_TTL: int = int(os.getenv("TOOLS_CACHE_TTL", 3600))
|
||||
|
@ -22,6 +22,7 @@ import src.api.contact_routes
|
||||
import src.api.mcp_server_routes
|
||||
import src.api.tool_routes
|
||||
import src.api.client_routes
|
||||
import src.api.a2a_routes
|
||||
|
||||
# Add the root directory to PYTHONPATH
|
||||
root_dir = Path(__file__).parent.parent
|
||||
@ -40,13 +41,13 @@ app = FastAPI(
|
||||
# CORS configuration
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Permite todas as origens em desenvolvimento
|
||||
allow_origins=["*"], # Allows all origins in development
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Configuração de arquivos estáticos
|
||||
# Static files configuration
|
||||
static_dir = Path("static")
|
||||
if not static_dir.exists():
|
||||
static_dir.mkdir(parents=True)
|
||||
@ -72,6 +73,7 @@ contact_router = src.api.contact_routes.router
|
||||
mcp_server_router = src.api.mcp_server_routes.router
|
||||
tool_router = src.api.tool_routes.router
|
||||
client_router = src.api.client_routes.router
|
||||
a2a_router = src.api.a2a_routes.router
|
||||
|
||||
# Include routes
|
||||
app.include_router(auth_router, prefix=API_PREFIX)
|
||||
@ -83,6 +85,7 @@ app.include_router(chat_router, prefix=API_PREFIX)
|
||||
app.include_router(session_router, prefix=API_PREFIX)
|
||||
app.include_router(agent_router, prefix=API_PREFIX)
|
||||
app.include_router(contact_router, prefix=API_PREFIX)
|
||||
app.include_router(a2a_router, prefix=API_PREFIX)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
9
src/schemas/a2a/__init__.py
Normal file
9
src/schemas/a2a/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) schema package.
|
||||
|
||||
This package contains Pydantic schema definitions for the A2A protocol.
|
||||
"""
|
||||
|
||||
from src.schemas.a2a.types import *
|
||||
from src.schemas.a2a.exceptions import *
|
||||
from src.schemas.a2a.validators import *
|
147
src/schemas/a2a/exceptions.py
Normal file
147
src/schemas/a2a/exceptions.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) protocol exception definitions.
|
||||
|
||||
This module contains error types and exceptions for the A2A protocol.
|
||||
"""
|
||||
|
||||
from src.schemas.a2a.types import JSONRPCError
|
||||
|
||||
|
||||
class JSONParseError(JSONRPCError):
|
||||
"""
|
||||
Error raised when JSON parsing fails.
|
||||
"""
|
||||
|
||||
code: int = -32700
|
||||
message: str = "Invalid JSON payload"
|
||||
data: object | None = None
|
||||
|
||||
|
||||
class InvalidRequestError(JSONRPCError):
|
||||
"""
|
||||
Error raised when request validation fails.
|
||||
"""
|
||||
|
||||
code: int = -32600
|
||||
message: str = "Request payload validation error"
|
||||
data: object | None = None
|
||||
|
||||
|
||||
class MethodNotFoundError(JSONRPCError):
|
||||
"""
|
||||
Error raised when the requested method is not found.
|
||||
"""
|
||||
|
||||
code: int = -32601
|
||||
message: str = "Method not found"
|
||||
data: None = None
|
||||
|
||||
|
||||
class InvalidParamsError(JSONRPCError):
|
||||
"""
|
||||
Error raised when the parameters are invalid.
|
||||
"""
|
||||
|
||||
code: int = -32602
|
||||
message: str = "Invalid parameters"
|
||||
data: object | None = None
|
||||
|
||||
|
||||
class InternalError(JSONRPCError):
|
||||
"""
|
||||
Error raised when an internal error occurs.
|
||||
"""
|
||||
|
||||
code: int = -32603
|
||||
message: str = "Internal error"
|
||||
data: object | None = None
|
||||
|
||||
|
||||
class TaskNotFoundError(JSONRPCError):
|
||||
"""
|
||||
Error raised when the requested task is not found.
|
||||
"""
|
||||
|
||||
code: int = -32001
|
||||
message: str = "Task not found"
|
||||
data: None = None
|
||||
|
||||
|
||||
class TaskNotCancelableError(JSONRPCError):
|
||||
"""
|
||||
Error raised when a task cannot be canceled.
|
||||
"""
|
||||
|
||||
code: int = -32002
|
||||
message: str = "Task cannot be canceled"
|
||||
data: None = None
|
||||
|
||||
|
||||
class PushNotificationNotSupportedError(JSONRPCError):
|
||||
"""
|
||||
Error raised when push notifications are not supported.
|
||||
"""
|
||||
|
||||
code: int = -32003
|
||||
message: str = "Push Notification is not supported"
|
||||
data: None = None
|
||||
|
||||
|
||||
class UnsupportedOperationError(JSONRPCError):
|
||||
"""
|
||||
Error raised when an operation is not supported.
|
||||
"""
|
||||
|
||||
code: int = -32004
|
||||
message: str = "This operation is not supported"
|
||||
data: None = None
|
||||
|
||||
|
||||
class ContentTypeNotSupportedError(JSONRPCError):
|
||||
"""
|
||||
Error raised when content types are incompatible.
|
||||
"""
|
||||
|
||||
code: int = -32005
|
||||
message: str = "Incompatible content types"
|
||||
data: None = None
|
||||
|
||||
|
||||
# Client exceptions
|
||||
|
||||
|
||||
class A2AClientError(Exception):
|
||||
"""
|
||||
Base exception for A2A client errors.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class A2AClientHTTPError(A2AClientError):
|
||||
"""
|
||||
Exception for HTTP errors in A2A client.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f"HTTP Error {status_code}: {message}")
|
||||
|
||||
|
||||
class A2AClientJSONError(A2AClientError):
|
||||
"""
|
||||
Exception for JSON errors in A2A client.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(f"JSON Error: {message}")
|
||||
|
||||
|
||||
class MissingAPIKeyError(Exception):
|
||||
"""
|
||||
Exception for missing API key.
|
||||
"""
|
||||
|
||||
pass
|
464
src/schemas/a2a/types.py
Normal file
464
src/schemas/a2a/types.py
Normal file
@ -0,0 +1,464 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) protocol type definitions.
|
||||
|
||||
This module contains Pydantic schema definitions for the A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import Union, Any, List, Optional, Annotated, Literal
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
ConfigDict,
|
||||
)
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
from enum import Enum
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""
|
||||
Enum for the state of a task in the A2A protocol.
|
||||
|
||||
States follow the A2A protocol specification.
|
||||
"""
|
||||
|
||||
SUBMITTED = "submitted"
|
||||
WORKING = "working"
|
||||
INPUT_REQUIRED = "input-required"
|
||||
COMPLETED = "completed"
|
||||
CANCELED = "canceled"
|
||||
FAILED = "failed"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class TextPart(BaseModel):
|
||||
"""
|
||||
Represents a text part in a message.
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
"""
|
||||
Represents file content in a file part.
|
||||
|
||||
Either bytes or uri must be provided, but not both.
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
mimeType: str | None = None
|
||||
bytes: str | None = None
|
||||
uri: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content(self) -> Self:
|
||||
"""
|
||||
Validates that either bytes or uri is present, but not both.
|
||||
"""
|
||||
if not (self.bytes or self.uri):
|
||||
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
|
||||
if self.bytes and self.uri:
|
||||
raise ValueError(
|
||||
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class FilePart(BaseModel):
|
||||
"""
|
||||
Represents a file part in a message.
|
||||
"""
|
||||
|
||||
type: Literal["file"] = "file"
|
||||
file: FileContent
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DataPart(BaseModel):
|
||||
"""
|
||||
Represents a data part in a message.
|
||||
"""
|
||||
|
||||
type: Literal["data"] = "data"
|
||||
data: dict[str, Any]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""
|
||||
Represents a message in the A2A protocol.
|
||||
|
||||
A message consists of a role and one or more parts.
|
||||
"""
|
||||
|
||||
role: Literal["user", "agent"]
|
||||
parts: List[Part]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatus(BaseModel):
|
||||
"""
|
||||
Represents the status of a task.
|
||||
"""
|
||||
|
||||
state: TaskState
|
||||
message: Message | None = None
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@field_serializer("timestamp")
|
||||
def serialize_dt(self, dt: datetime, _info):
|
||||
"""
|
||||
Serializes datetime to ISO format.
|
||||
"""
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
"""
|
||||
Represents an artifact produced by an agent.
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
parts: List[Part]
|
||||
metadata: dict[str, Any] | None = None
|
||||
index: int = 0
|
||||
append: bool | None = None
|
||||
lastChunk: bool | None = None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""
|
||||
Represents a task in the A2A protocol.
|
||||
"""
|
||||
|
||||
id: str
|
||||
sessionId: str | None = None
|
||||
status: TaskStatus
|
||||
artifacts: List[Artifact] | None = None
|
||||
history: List[Message] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
"""
|
||||
Represents a task status update event.
|
||||
"""
|
||||
|
||||
id: str
|
||||
status: TaskStatus
|
||||
final: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
"""
|
||||
Represents a task artifact update event.
|
||||
"""
|
||||
|
||||
id: str
|
||||
artifact: Artifact
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AuthenticationInfo(BaseModel):
|
||||
"""
|
||||
Represents authentication information for push notifications.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
schemes: List[str]
|
||||
credentials: str | None = None
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
"""
|
||||
Represents push notification configuration.
|
||||
"""
|
||||
|
||||
url: str
|
||||
token: str | None = None
|
||||
authentication: AuthenticationInfo | None = None
|
||||
|
||||
|
||||
class TaskIdParams(BaseModel):
|
||||
"""
|
||||
Represents parameters for identifying a task.
|
||||
"""
|
||||
|
||||
id: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskQueryParams(TaskIdParams):
|
||||
"""
|
||||
Represents parameters for querying a task.
|
||||
"""
|
||||
|
||||
historyLength: int | None = None
|
||||
|
||||
|
||||
class TaskSendParams(BaseModel):
|
||||
"""
|
||||
Represents parameters for sending a task.
|
||||
"""
|
||||
|
||||
id: str
|
||||
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||
message: Message
|
||||
acceptedOutputModes: Optional[List[str]] = None
|
||||
pushNotification: PushNotificationConfig | None = None
|
||||
historyLength: int | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agentId: str = ""
|
||||
|
||||
|
||||
class TaskPushNotificationConfig(BaseModel):
|
||||
"""
|
||||
Represents push notification configuration for a task.
|
||||
"""
|
||||
|
||||
id: str
|
||||
pushNotificationConfig: PushNotificationConfig
|
||||
|
||||
|
||||
# RPC Messages
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
|
||||
"""
|
||||
Base class for JSON-RPC messages.
|
||||
"""
|
||||
|
||||
jsonrpc: Literal["2.0"] = "2.0"
|
||||
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
|
||||
|
||||
|
||||
class JSONRPCRequest(JSONRPCMessage):
|
||||
"""
|
||||
Represents a JSON-RPC request.
|
||||
"""
|
||||
|
||||
method: str
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
"""
|
||||
Represents a JSON-RPC error.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class JSONRPCResponse(JSONRPCMessage):
|
||||
"""
|
||||
Represents a JSON-RPC response.
|
||||
"""
|
||||
|
||||
result: Any | None = None
|
||||
error: JSONRPCError | None = None
|
||||
|
||||
|
||||
class SendTaskRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to send a task.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/send"] = "tasks/send"
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskResponse(JSONRPCResponse):
|
||||
"""
|
||||
Represents a response to a send task request.
|
||||
"""
|
||||
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to send a task with streaming.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||
"""
|
||||
Represents a streaming response to a send task request.
|
||||
"""
|
||||
|
||||
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
|
||||
|
||||
|
||||
class GetTaskRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to get task information.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/get"] = "tasks/get"
|
||||
params: TaskQueryParams
|
||||
|
||||
|
||||
class GetTaskResponse(JSONRPCResponse):
|
||||
"""
|
||||
Represents a response to a get task request.
|
||||
"""
|
||||
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class CancelTaskRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to cancel a task.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/cancel",] = "tasks/cancel"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class CancelTaskResponse(JSONRPCResponse):
|
||||
"""
|
||||
Represents a response to a cancel task request.
|
||||
"""
|
||||
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to set push notification for a task.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
|
||||
params: TaskPushNotificationConfig
|
||||
|
||||
|
||||
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
"""
|
||||
Represents a response to a set push notification request.
|
||||
"""
|
||||
|
||||
result: TaskPushNotificationConfig | None = None
|
||||
|
||||
|
||||
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to get push notification configuration for a task.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
"""
|
||||
Represents a response to a get push notification request.
|
||||
"""
|
||||
|
||||
result: TaskPushNotificationConfig | None = None
|
||||
|
||||
|
||||
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||
"""
|
||||
Represents a request to resubscribe to a task.
|
||||
"""
|
||||
|
||||
method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
# TypeAdapter for discriminating A2A requests by method
|
||||
A2ARequest = TypeAdapter(
|
||||
Annotated[
|
||||
Union[
|
||||
SendTaskRequest,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
],
|
||||
Field(discriminator="method"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Agent Card schemas
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
"""
|
||||
Represents the provider of an agent.
|
||||
"""
|
||||
|
||||
organization: str
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class AgentCapabilities(BaseModel):
|
||||
"""
|
||||
Represents the capabilities of an agent.
|
||||
"""
|
||||
|
||||
streaming: bool = False
|
||||
pushNotifications: bool = False
|
||||
stateTransitionHistory: bool = False
|
||||
|
||||
|
||||
class AgentAuthentication(BaseModel):
|
||||
"""
|
||||
Represents the authentication requirements for an agent.
|
||||
"""
|
||||
|
||||
schemes: List[str]
|
||||
credentials: str | None = None
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
"""
|
||||
Represents a skill of an agent.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
tags: List[str] | None = None
|
||||
examples: List[str] | None = None
|
||||
inputModes: List[str] | None = None
|
||||
outputModes: List[str] | None = None
|
||||
|
||||
|
||||
class AgentCard(BaseModel):
|
||||
"""
|
||||
Represents an agent card in the A2A protocol.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
url: str
|
||||
provider: AgentProvider | None = None
|
||||
version: str
|
||||
documentationUrl: str | None = None
|
||||
capabilities: AgentCapabilities
|
||||
authentication: AgentAuthentication | None = None
|
||||
defaultInputModes: List[str] = ["text"]
|
||||
defaultOutputModes: List[str] = ["text"]
|
||||
skills: List[AgentSkill]
|
124
src/schemas/a2a/validators.py
Normal file
124
src/schemas/a2a/validators.py
Normal file
@ -0,0 +1,124 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) protocol validators.
|
||||
|
||||
This module contains validators for the A2A protocol data.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import base64
|
||||
import re
|
||||
from pydantic import ValidationError
|
||||
import logging
|
||||
from src.schemas.a2a.types import Part, TextPart, FilePart, DataPart, FileContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_base64(value: str) -> bool:
|
||||
"""
|
||||
Validates if a string is valid base64.
|
||||
|
||||
Args:
|
||||
value: String to validate
|
||||
|
||||
Returns:
|
||||
True if valid base64, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not value:
|
||||
return False
|
||||
|
||||
# Check if the string has base64 characters only
|
||||
pattern = r"^[A-Za-z0-9+/]+={0,2}$"
|
||||
if not re.match(pattern, value):
|
||||
return False
|
||||
|
||||
# Try to decode
|
||||
base64.b64decode(value)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid base64 string: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def validate_file_content(file_content: FileContent) -> bool:
|
||||
"""
|
||||
Validates file content.
|
||||
|
||||
Args:
|
||||
file_content: FileContent to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
if file_content.bytes is not None:
|
||||
return validate_base64(file_content.bytes)
|
||||
elif file_content.uri is not None:
|
||||
# Basic URL validation
|
||||
pattern = r"^https?://.+"
|
||||
return bool(re.match(pattern, file_content.uri))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid file content: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def validate_message_parts(parts: List[Part]) -> bool:
|
||||
"""
|
||||
Validates all parts in a message.
|
||||
|
||||
Args:
|
||||
parts: List of parts to validate
|
||||
|
||||
Returns:
|
||||
True if all parts are valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
for part in parts:
|
||||
if isinstance(part, TextPart):
|
||||
if not part.text or not isinstance(part.text, str):
|
||||
return False
|
||||
elif isinstance(part, FilePart):
|
||||
if not validate_file_content(part.file):
|
||||
return False
|
||||
elif isinstance(part, DataPart):
|
||||
if not part.data or not isinstance(part.data, dict):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
except (ValidationError, Exception) as e:
|
||||
logger.warning(f"Invalid message parts: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def text_to_parts(text: str) -> List[Part]:
|
||||
"""
|
||||
Converts a plain text to a list of message parts.
|
||||
|
||||
Args:
|
||||
text: Plain text to convert
|
||||
|
||||
Returns:
|
||||
List containing a single TextPart
|
||||
"""
|
||||
return [TextPart(text=text)]
|
||||
|
||||
|
||||
def parts_to_text(parts: List[Part]) -> str:
|
||||
"""
|
||||
Extracts text from a list of message parts.
|
||||
|
||||
Args:
|
||||
parts: List of parts to extract text from
|
||||
|
||||
Returns:
|
||||
Concatenated text from all text parts
|
||||
"""
|
||||
text = ""
|
||||
for part in parts:
|
||||
if isinstance(part, TextPart):
|
||||
text += part.text
|
||||
# Could add handling for other part types here
|
||||
return text
|
@ -1 +1,9 @@
|
||||
from .agent_runner import run_agent
|
||||
from .redis_cache_service import RedisCacheService
|
||||
from .a2a_task_manager_service import A2ATaskManager
|
||||
from .a2a_server_service import A2AServer
|
||||
from .a2a_integration_service import (
|
||||
AgentRunnerAdapter,
|
||||
StreamingServiceAdapter,
|
||||
create_agent_card_from_agent,
|
||||
)
|
||||
|
520
src/services/a2a_integration_service.py
Normal file
520
src/services/a2a_integration_service.py
Normal file
@ -0,0 +1,520 @@
|
||||
"""
|
||||
A2A Integration Service.
|
||||
|
||||
This service provides adapters to integrate existing services with the A2A protocol.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, AsyncIterable
|
||||
|
||||
from src.schemas.a2a import (
|
||||
AgentCard,
|
||||
AgentCapabilities,
|
||||
AgentProvider,
|
||||
Artifact,
|
||||
Message,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRunnerAdapter:
|
||||
"""
|
||||
Adapter for integrating the existing agent runner with the A2A protocol.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_runner_func,
|
||||
session_service=None,
|
||||
artifacts_service=None,
|
||||
memory_service=None,
|
||||
db=None,
|
||||
):
|
||||
"""
|
||||
Initialize the adapter.
|
||||
|
||||
Args:
|
||||
agent_runner_func: The agent runner function (e.g., run_agent)
|
||||
session_service: Session service for message history
|
||||
artifacts_service: Artifacts service for artifact history
|
||||
memory_service: Memory service for agent memory
|
||||
db: Database session
|
||||
"""
|
||||
self.agent_runner_func = agent_runner_func
|
||||
self.session_service = session_service
|
||||
self.artifacts_service = artifacts_service
|
||||
self.memory_service = memory_service
|
||||
self.db = db
|
||||
|
||||
async def get_supported_modes(self) -> List[str]:
|
||||
"""
|
||||
Get the supported output modes for the agent.
|
||||
|
||||
Returns:
|
||||
List of supported output modes
|
||||
"""
|
||||
# Default modes, can be extended based on agent configuration
|
||||
return ["text", "application/json"]
|
||||
|
||||
async def run_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
session_id: Optional[str] = None,
|
||||
task_id: Optional[str] = None,
|
||||
db=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run the agent with the given message.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to run
|
||||
message: User message to process
|
||||
session_id: Optional session ID for conversation context
|
||||
task_id: Optional task ID for tracking
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary with the agent's response
|
||||
"""
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] run_agent iniciado - agent_id={agent_id}, task_id={task_id}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] run_agent - message: '{message[:50]}...' (truncado)"
|
||||
)
|
||||
|
||||
try:
|
||||
# Use the existing agent runner function
|
||||
session_id = session_id or str(uuid.uuid4())
|
||||
task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
# Use the provided db or fallback to self.db
|
||||
db_session = db if db is not None else self.db
|
||||
|
||||
if db_session is None:
|
||||
logger.error(
|
||||
f"[AGENT-RUNNER] No database session available. db={db}, self.db={self.db}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] Using database session: {type(db_session).__name__}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] Chamando agent_runner_func com agent_id={agent_id}, contact_id={task_id}"
|
||||
)
|
||||
response_text = await self.agent_runner_func(
|
||||
agent_id=agent_id,
|
||||
contact_id=task_id,
|
||||
message=message,
|
||||
session_service=self.session_service,
|
||||
artifacts_service=self.artifacts_service,
|
||||
memory_service=self.memory_service,
|
||||
db=db_session,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] run_agent concluído com sucesso para agent_id={agent_id}, task_id={task_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] resposta: '{str(response_text)[:50]}...' (truncado)"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"content": response_text,
|
||||
"session_id": session_id,
|
||||
"task_id": task_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AGENT-RUNNER] Error running agent: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"session_id": session_id,
|
||||
"task_id": task_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
async def cancel_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: ID of the task to cancel
|
||||
|
||||
Returns:
|
||||
True if successfully canceled, False otherwise
|
||||
"""
|
||||
# Currently, the agent runner doesn't support cancellation
|
||||
# This is a placeholder for future implementation
|
||||
logger.warning(f"Task cancellation not implemented for task {task_id}")
|
||||
return False
|
||||
|
||||
|
||||
class StreamingServiceAdapter:
|
||||
"""
|
||||
Adapter for integrating the existing streaming service with the A2A protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, streaming_service):
|
||||
"""
|
||||
Initialize the adapter.
|
||||
|
||||
Args:
|
||||
streaming_service: The streaming service instance
|
||||
"""
|
||||
self.streaming_service = streaming_service
|
||||
|
||||
async def stream_agent_response(
|
||||
self,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
api_key: str,
|
||||
session_id: Optional[str] = None,
|
||||
task_id: Optional[str] = None,
|
||||
db=None,
|
||||
) -> AsyncIterable[str]:
|
||||
"""
|
||||
Stream the agent's response as A2A events.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
message: User message to process
|
||||
api_key: API key for authentication
|
||||
session_id: Optional session ID for conversation context
|
||||
task_id: Optional task ID for tracking
|
||||
db: Database session
|
||||
|
||||
Yields:
|
||||
A2A event objects as JSON strings for SSE (Server-Sent Events)
|
||||
"""
|
||||
task_id = task_id or str(uuid.uuid4())
|
||||
logger.info(f"Starting streaming response for task {task_id}")
|
||||
|
||||
# Set working status event
|
||||
working_status = TaskStatus(
|
||||
state="working",
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent", parts=[TextPart(text="Processing your request...")]
|
||||
),
|
||||
)
|
||||
|
||||
status_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=working_status, final=False
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(status_event.model_dump())
|
||||
|
||||
content_buffer = ""
|
||||
final_sent = False
|
||||
has_error = False
|
||||
|
||||
# Stream from the existing streaming service
|
||||
try:
|
||||
logger.info(f"Setting up streaming for agent {agent_id}, task {task_id}")
|
||||
# To streaming, we use task_id as contact_id
|
||||
contact_id = task_id
|
||||
|
||||
# Adicionar tratamento de heartbeat para manter conexão ativa
|
||||
last_event_time = datetime.now()
|
||||
heartbeat_interval = 20 # segundos
|
||||
|
||||
async for event in self.streaming_service.send_task_streaming(
|
||||
agent_id=agent_id,
|
||||
api_key=api_key,
|
||||
message=message,
|
||||
contact_id=contact_id,
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
):
|
||||
# Atualizar timestamp do último evento
|
||||
last_event_time = datetime.now()
|
||||
|
||||
# Process the streaming event format
|
||||
event_data = event.get("data", "{}")
|
||||
try:
|
||||
logger.info(f"Processing event data: {event_data[:100]}...")
|
||||
data = json.loads(event_data)
|
||||
|
||||
# Extract content
|
||||
if "delta" in data and data["delta"].get("content"):
|
||||
content = data["delta"]["content"]
|
||||
content_buffer += content
|
||||
logger.info(f"Received content chunk: {content[:50]}...")
|
||||
|
||||
# Create artifact update event
|
||||
artifact = Artifact(
|
||||
name="response",
|
||||
parts=[TextPart(text=content)],
|
||||
index=0,
|
||||
append=True,
|
||||
lastChunk=False,
|
||||
)
|
||||
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(artifact_event.model_dump())
|
||||
|
||||
# Check if final event
|
||||
if data.get("done", False) and not final_sent:
|
||||
logger.info(f"Received final event for task {task_id}")
|
||||
# Create completed status event
|
||||
completed_status = TaskStatus(
|
||||
state="completed",
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[
|
||||
TextPart(text=content_buffer or "Task completed")
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Final artifact with full content
|
||||
final_artifact = Artifact(
|
||||
name="response",
|
||||
parts=[TextPart(text=content_buffer)],
|
||||
index=0,
|
||||
append=False,
|
||||
lastChunk=True,
|
||||
)
|
||||
|
||||
# Send the final artifact
|
||||
final_artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=final_artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(final_artifact_event.model_dump())
|
||||
|
||||
# Send the completed status
|
||||
final_status_event = TaskStatusUpdateEvent(
|
||||
id=task_id,
|
||||
status=completed_status,
|
||||
final=True,
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(final_status_event.model_dump())
|
||||
|
||||
final_sent = True
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Received non-JSON event data: {e}. Data: {event_data[:100]}..."
|
||||
)
|
||||
# Handle non-JSON events - simply add to buffer as text
|
||||
if isinstance(event_data, str):
|
||||
content_buffer += event_data
|
||||
|
||||
# Create artifact update event
|
||||
artifact = Artifact(
|
||||
name="response",
|
||||
parts=[TextPart(text=event_data)],
|
||||
index=0,
|
||||
append=True,
|
||||
lastChunk=False,
|
||||
)
|
||||
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(artifact_event.model_dump())
|
||||
elif isinstance(event_data, dict):
|
||||
# Try to extract text from the dictionary
|
||||
text_value = str(event_data)
|
||||
content_buffer += text_value
|
||||
|
||||
artifact = Artifact(
|
||||
name="response",
|
||||
parts=[TextPart(text=text_value)],
|
||||
index=0,
|
||||
append=True,
|
||||
lastChunk=False,
|
||||
)
|
||||
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(artifact_event.model_dump())
|
||||
|
||||
# Enviar heartbeat/keep-alive para manter a conexão SSE aberta
|
||||
now = datetime.now()
|
||||
if (now - last_event_time).total_seconds() > heartbeat_interval:
|
||||
logger.info(f"Sending heartbeat for task {task_id}")
|
||||
# Enviando evento de keep-alive como um evento de status de "working"
|
||||
working_heartbeat = TaskStatus(
|
||||
state="working",
|
||||
timestamp=now,
|
||||
message=Message(
|
||||
role="agent", parts=[TextPart(text="Still processing...")]
|
||||
),
|
||||
)
|
||||
heartbeat_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=working_heartbeat, final=False
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(heartbeat_event.model_dump())
|
||||
last_event_time = now
|
||||
|
||||
# Ensure we send a final event if not already sent
|
||||
if not final_sent:
|
||||
logger.info(
|
||||
f"Stream completed for task {task_id}, sending final status"
|
||||
)
|
||||
# Create completed status event
|
||||
completed_status = TaskStatus(
|
||||
state="completed",
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text=content_buffer or "Task completed")],
|
||||
),
|
||||
)
|
||||
|
||||
# Send the completed status
|
||||
final_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=completed_status, final=True
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(final_event.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
has_error = True
|
||||
logger.error(f"Error in streaming for task {task_id}: {e}", exc_info=True)
|
||||
|
||||
# Create failed status event
|
||||
failed_status = TaskStatus(
|
||||
state="failed",
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[
|
||||
TextPart(
|
||||
text=f"Error during streaming: {str(e)}. Partial response: {content_buffer[:200] if content_buffer else 'No content received'}"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
error_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=failed_status, final=True
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(error_event.model_dump())
|
||||
|
||||
finally:
|
||||
# Garantir que enviamos um evento final para fechar a conexão corretamente
|
||||
if not final_sent and not has_error:
|
||||
logger.info(f"Stream finalizing for task {task_id} via finally block")
|
||||
try:
|
||||
# Create completed status event
|
||||
completed_status = TaskStatus(
|
||||
state="completed",
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[
|
||||
TextPart(
|
||||
text=content_buffer or "Task completed (forced end)"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Send the completed status
|
||||
final_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=completed_status, final=True
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(final_event.model_dump())
|
||||
except Exception as final_error:
|
||||
logger.error(
|
||||
f"Error sending final event in finally block: {final_error}"
|
||||
)
|
||||
|
||||
logger.info(f"Streaming completed for task {task_id}")
|
||||
|
||||
|
||||
def create_agent_card_from_agent(agent, db) -> AgentCard:
|
||||
"""
|
||||
Create an A2A agent card from an agent model.
|
||||
|
||||
Args:
|
||||
agent: The agent model from the database
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
A2A AgentCard object
|
||||
"""
|
||||
import os
|
||||
from src.api.agent_routes import format_agent_tools
|
||||
import asyncio
|
||||
|
||||
# Extract agent configuration
|
||||
agent_config = agent.config
|
||||
has_streaming = True # Assuming streaming is always supported
|
||||
has_push = True # Assuming push notifications are supported
|
||||
|
||||
# Format tools as skills
|
||||
try:
|
||||
# We use a different approach to handle the asynchronous function
|
||||
mcp_servers = agent_config.get("mcp_servers", [])
|
||||
|
||||
# We create a new thread to execute the asynchronous function
|
||||
import concurrent.futures
|
||||
import functools
|
||||
|
||||
def run_async(coro):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result = loop.run_until_complete(coro)
|
||||
loop.close()
|
||||
return result
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_async, format_agent_tools(mcp_servers, db))
|
||||
skills = future.result()
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting agent tools: {e}")
|
||||
skills = []
|
||||
|
||||
# Create agent card
|
||||
return AgentCard(
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
url=f"{os.getenv('API_URL', '')}/api/v1/a2a/{agent.id}",
|
||||
provider=AgentProvider(
|
||||
organization=os.getenv("ORGANIZATION_NAME", ""),
|
||||
url=os.getenv("ORGANIZATION_URL", ""),
|
||||
),
|
||||
version=os.getenv("API_VERSION", "1.0.0"),
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=has_streaming,
|
||||
pushNotifications=has_push,
|
||||
stateTransitionHistory=True,
|
||||
),
|
||||
authentication={
|
||||
"schemes": ["apiKey"],
|
||||
"credentials": "x-api-key",
|
||||
},
|
||||
defaultInputModes=["text", "application/json"],
|
||||
defaultOutputModes=["text", "application/json"],
|
||||
skills=skills,
|
||||
)
|
739
src/services/a2a_server_service.py
Normal file
739
src/services/a2a_server_service.py
Normal file
@ -0,0 +1,739 @@
|
||||
"""
|
||||
Server A2A and task manager for the A2A protocol.
|
||||
|
||||
This module implements a JSON-RPC compatible server for the A2A protocol,
|
||||
that manages agent tasks, streaming events and push notifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Union,
|
||||
AsyncIterable,
|
||||
)
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.schemas.a2a.types import A2ARequest
|
||||
from src.services.agent_runner import run_agent
|
||||
from src.services.a2a_integration_service import (
|
||||
AgentRunnerAdapter,
|
||||
StreamingServiceAdapter,
|
||||
)
|
||||
from src.services.session_service import get_session_events
|
||||
from src.services.redis_cache_service import RedisCacheService
|
||||
from src.schemas.a2a.types import (
|
||||
SendTaskRequest,
|
||||
SendTaskStreamingRequest,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
TaskResubscriptionRequest,
|
||||
TaskSendParams,
|
||||
)
|
||||
from src.utils.a2a_utils import are_modalities_compatible
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2ATaskManager:
|
||||
"""
|
||||
Task manager for the A2A protocol.
|
||||
|
||||
This class manages the lifecycle of tasks, including:
|
||||
- Task execution
|
||||
- Streaming of events
|
||||
- Push notifications
|
||||
- Status querying
|
||||
- Cancellation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_cache: RedisCacheService,
|
||||
agent_runner: AgentRunnerAdapter,
|
||||
streaming_service: StreamingServiceAdapter,
|
||||
push_notification_service: Any = None,
|
||||
):
|
||||
"""
|
||||
Initialize the task manager.
|
||||
|
||||
Args:
|
||||
redis_cache: Cache service for storing task data
|
||||
agent_runner: Adapter for agent execution
|
||||
streaming_service: Adapter for event streaming
|
||||
push_notification_service: Service for sending push notifications
|
||||
"""
|
||||
self.cache = redis_cache
|
||||
self.agent_runner = agent_runner
|
||||
self.streaming_service = streaming_service
|
||||
self.push_notification_service = push_notification_service
|
||||
self._running_tasks = {}
|
||||
|
||||
async def on_send_task(
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
message: Dict[str, Any],
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
input_mode: str = "text",
|
||||
output_modes: List[str] = ["text"],
|
||||
db: Optional[Session] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a request to send a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session ID (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
input_mode: Input mode (text, JSON, etc.)
|
||||
output_modes: Supported output modes
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Response with task result
|
||||
"""
|
||||
if not session_id:
|
||||
session_id = f"{task_id}_{agent_id}"
|
||||
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
# Update status to "submitted"
|
||||
task_data = {
|
||||
"id": task_id,
|
||||
"sessionId": session_id,
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": None,
|
||||
"error": None,
|
||||
},
|
||||
"artifacts": [],
|
||||
"history": [],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
# Store initial task data
|
||||
await self.cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Check for push notification configurations
|
||||
push_config = await self.cache.get(f"task:{task_id}:push")
|
||||
if push_config and self.push_notification_service:
|
||||
# Send initial notification
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="submitted",
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
try:
|
||||
# Update status to "running"
|
||||
task_data["status"].update(
|
||||
{"state": "running", "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
await self.cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Notify "running" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="running",
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
# Extract user message
|
||||
user_message = None
|
||||
try:
|
||||
user_message = message["parts"][0]["text"]
|
||||
except (KeyError, IndexError):
|
||||
user_message = ""
|
||||
|
||||
# Execute the agent
|
||||
response = await self.agent_runner.run_agent(
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
message=user_message,
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Check if the response is a dictionary (error) or a string (success)
|
||||
if isinstance(response, dict) and response.get("status") == "error":
|
||||
# Error response
|
||||
final_response = f"Error: {response.get('error', 'Unknown error')}"
|
||||
|
||||
# Update status to "failed"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"error": {
|
||||
"code": "AGENT_EXECUTION_ERROR",
|
||||
"message": response.get("error", "Unknown error"),
|
||||
},
|
||||
"message": {
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Notify "failed" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="failed",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
else:
|
||||
# Success response
|
||||
final_response = (
|
||||
response.get("content") if isinstance(response, dict) else response
|
||||
)
|
||||
|
||||
# Update status to "completed"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add artifacts
|
||||
if final_response:
|
||||
task_data["artifacts"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": final_response,
|
||||
"metadata": {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"content_type": "text/plain",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add history of messages
|
||||
history_length = metadata.get("historyLength", 50)
|
||||
try:
|
||||
history_messages = get_session_events(
|
||||
self.agent_runner.session_service, session_id
|
||||
)
|
||||
history_messages = history_messages[-history_length:]
|
||||
|
||||
formatted_history = []
|
||||
for event in history_messages:
|
||||
if event.content and event.content.parts:
|
||||
role = (
|
||||
"agent"
|
||||
if event.content.role == "model"
|
||||
else event.content.role
|
||||
)
|
||||
formatted_history.append(
|
||||
{
|
||||
"role": role,
|
||||
"parts": [
|
||||
{"type": "text", "text": part.text}
|
||||
for part in event.content.parts
|
||||
if part.text
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
task_data["history"] = formatted_history
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing history: {str(e)}")
|
||||
|
||||
# Notify "completed" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="completed",
|
||||
message={
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task {task_id}: {str(e)}")
|
||||
|
||||
# Update status to "failed"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"error": {"code": "AGENT_EXECUTION_ERROR", "message": str(e)},
|
||||
}
|
||||
)
|
||||
|
||||
# Notify "failed" state
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="failed",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": str(e)}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
# Store final result
|
||||
await self.cache.set(f"task:{task_id}", task_data)
|
||||
return task_data
|
||||
|
||||
async def on_send_task_subscribe(
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
message: Dict[str, Any],
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
input_mode: str = "text",
|
||||
output_modes: List[str] = ["text"],
|
||||
db: Optional[Session] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Process a request to send a task with streaming.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session ID (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
input_mode: Input mode (text, JSON, etc.)
|
||||
output_modes: Supported output modes
|
||||
db: Database session
|
||||
|
||||
Yields:
|
||||
Streaming events in SSE (Server-Sent Events) format
|
||||
"""
|
||||
if not session_id:
|
||||
session_id = f"{task_id}_{agent_id}"
|
||||
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
# Extract user message
|
||||
user_message = ""
|
||||
try:
|
||||
user_message = message["parts"][0]["text"]
|
||||
except (KeyError, IndexError):
|
||||
pass
|
||||
|
||||
# Generate streaming events
|
||||
async for event in self.streaming_service.stream_response(
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
message=user_message,
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
):
|
||||
yield event
|
||||
|
||||
async def on_get_task(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Query the status of a task by ID.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Current task status
|
||||
|
||||
Raises:
|
||||
Exception: If the task is not found
|
||||
"""
|
||||
task_data = await self.cache.get(f"task:{task_id}")
|
||||
if not task_data:
|
||||
raise Exception(f"Task {task_id} not found")
|
||||
return task_data
|
||||
|
||||
async def on_cancel_task(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to be cancelled
|
||||
|
||||
Returns:
|
||||
Task status after cancellation
|
||||
|
||||
Raises:
|
||||
Exception: If the task is not found or cannot be cancelled
|
||||
"""
|
||||
task_data = await self.cache.get(f"task:{task_id}")
|
||||
if not task_data:
|
||||
raise Exception(f"Task {task_id} not found")
|
||||
|
||||
# Check if the task is in a state that can be cancelled
|
||||
current_state = task_data["status"]["state"]
|
||||
if current_state not in ["submitted", "running"]:
|
||||
raise Exception(f"Cannot cancel task in {current_state} state")
|
||||
|
||||
# Cancel the task in the runner if it is running
|
||||
running_task = self._running_tasks.get(task_id)
|
||||
if running_task:
|
||||
# Try to cancel the running task
|
||||
if hasattr(running_task, "cancel"):
|
||||
running_task.cancel()
|
||||
|
||||
# Update status to "cancelled"
|
||||
task_data["status"].update(
|
||||
{
|
||||
"state": "cancelled",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": "Task cancelled by user"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Update cache
|
||||
await self.cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Send push notification if configured
|
||||
push_config = await self.cache.get(f"task:{task_id}:push")
|
||||
if push_config and self.push_notification_service:
|
||||
await self.push_notification_service.send_notification(
|
||||
url=push_config["url"],
|
||||
task_id=task_id,
|
||||
state="cancelled",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": "Task cancelled by user"}],
|
||||
},
|
||||
headers=push_config.get("headers", {}),
|
||||
)
|
||||
|
||||
return task_data
|
||||
|
||||
async def on_set_task_push_notification(
|
||||
self, task_id: str, notification_config: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Configure push notifications for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
notification_config: Notification configuration (URL and headers)
|
||||
|
||||
Returns:
|
||||
Updated configuration
|
||||
"""
|
||||
# Validate configuration
|
||||
url = notification_config.get("url")
|
||||
if not url:
|
||||
raise ValueError("Push notification URL is required")
|
||||
|
||||
headers = notification_config.get("headers", {})
|
||||
|
||||
# Store configuration
|
||||
config = {"url": url, "headers": headers}
|
||||
await self.cache.set(f"task:{task_id}:push", config)
|
||||
|
||||
return config
|
||||
|
||||
async def on_get_task_push_notification(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the push notification configuration for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Push notification configuration
|
||||
|
||||
Raises:
|
||||
Exception: If there is no configuration for the task
|
||||
"""
|
||||
config = await self.cache.get(f"task:{task_id}:push")
|
||||
if not config:
|
||||
raise Exception(f"No push notification configuration for task {task_id}")
|
||||
return config
|
||||
|
||||
|
||||
class A2AServer:
|
||||
"""
|
||||
A2A server compatible with JSON-RPC 2.0.
|
||||
|
||||
This class processes JSON-RPC requests and forwards them to
|
||||
the appropriate handlers in the A2ATaskManager.
|
||||
"""
|
||||
|
||||
def __init__(self, task_manager: A2ATaskManager, agent_card=None):
|
||||
"""
|
||||
Initialize the A2A server.
|
||||
|
||||
Args:
|
||||
task_manager: Task manager
|
||||
agent_card: Agent card information
|
||||
"""
|
||||
self.task_manager = task_manager
|
||||
self.agent_card = agent_card
|
||||
|
||||
async def process_request(
|
||||
self,
|
||||
request: Request,
|
||||
agent_id: Optional[str] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Union[Response, JSONResponse, StreamingResponse]:
|
||||
"""
|
||||
Process a JSON-RPC request.
|
||||
|
||||
Args:
|
||||
request: HTTP request
|
||||
agent_id: Optional agent ID to inject into the request
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Appropriate response (JSON or Streaming)
|
||||
"""
|
||||
try:
|
||||
# Try to parse the JSON payload
|
||||
try:
|
||||
logger.info("Starting JSON-RPC request processing")
|
||||
body = await request.json()
|
||||
logger.info(f"Received JSON data: {json.dumps(body)}")
|
||||
method = body.get("method", "unknown")
|
||||
logger.info(f"[SERVER] Processando método: {method}")
|
||||
|
||||
# Validate the request using the A2A validator
|
||||
json_rpc_request = A2ARequest.validate_python(body)
|
||||
logger.info(
|
||||
f"[SERVER] Request validado como: {type(json_rpc_request).__name__}"
|
||||
)
|
||||
|
||||
original_db = self.task_manager.db
|
||||
try:
|
||||
# Set the db temporarily
|
||||
if db is not None:
|
||||
self.task_manager.db = db
|
||||
|
||||
# Process the request
|
||||
if isinstance(json_rpc_request, SendTaskRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando SendTaskRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
json_rpc_request.params.agentId = agent_id
|
||||
result = await self.task_manager.on_send_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando SendTaskStreamingRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
json_rpc_request.params.agentId = agent_id
|
||||
result = await self.task_manager.on_send_task_subscribe(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, GetTaskRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando GetTaskRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_get_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, CancelTaskRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando CancelTaskRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_cancel_task(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando SetTaskPushNotificationRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_set_task_push_notification(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando GetTaskPushNotificationRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_get_task_push_notification(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando TaskResubscriptionRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_resubscribe_to_task(
|
||||
json_rpc_request
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SERVER] Tipo de request não suportado: {type(json_rpc_request)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": body.get("id"),
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
"data": {"detail": f"Method not supported"},
|
||||
},
|
||||
},
|
||||
)
|
||||
finally:
|
||||
# Restore the original db
|
||||
self.task_manager.db = original_db
|
||||
|
||||
# Create appropriate response
|
||||
return self._create_response(result)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
# Error parsing JSON
|
||||
logger.error(f"Error parsing JSON request: {str(e)}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {
|
||||
"code": -32700,
|
||||
"message": "Parse error",
|
||||
"data": {"detail": str(e)},
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
# Other validation errors
|
||||
logger.error(f"Error validating request: {str(e)}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": body.get("id") if "body" in locals() else None,
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request",
|
||||
"data": {"detail": str(e)},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JSON-RPC request: {str(e)}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": "Internal error",
|
||||
"data": {"detail": str(e)},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _create_response(self, result: Any) -> Union[JSONResponse, StreamingResponse]:
|
||||
"""
|
||||
Create appropriate response based on result type.
|
||||
|
||||
Args:
|
||||
result: Result from task manager
|
||||
|
||||
Returns:
|
||||
JSON or streaming response
|
||||
"""
|
||||
if isinstance(result, AsyncIterable):
|
||||
# Result in streaming (SSE)
|
||||
async def event_generator():
|
||||
async for item in result:
|
||||
if hasattr(item, "model_dump_json"):
|
||||
yield {"data": item.model_dump_json(exclude_none=True)}
|
||||
else:
|
||||
yield {"data": json.dumps(item)}
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
elif hasattr(result, "model_dump"):
|
||||
# Result is a Pydantic object
|
||||
return JSONResponse(result.model_dump(exclude_none=True))
|
||||
|
||||
else:
|
||||
# Result is a dictionary or other simple type
|
||||
return JSONResponse(result)
|
||||
|
||||
async def get_agent_card(
|
||||
self, request: Request, db: Optional[Session] = None
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Get the agent card.
|
||||
|
||||
Args:
|
||||
request: HTTP request
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Agent card as JSON
|
||||
"""
|
||||
if not self.agent_card:
|
||||
logger.error("Agent card not configured")
|
||||
return JSONResponse(
|
||||
status_code=404, content={"error": "Agent card not configured"}
|
||||
)
|
||||
|
||||
# If there is db, set it temporarily in the task_manager
|
||||
if db and hasattr(self.task_manager, "db"):
|
||||
original_db = self.task_manager.db
|
||||
try:
|
||||
self.task_manager.db = db
|
||||
|
||||
# If it's a Pydantic object, convert to dictionary
|
||||
if hasattr(self.agent_card, "model_dump"):
|
||||
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
|
||||
else:
|
||||
return JSONResponse(self.agent_card)
|
||||
finally:
|
||||
# Restore the original db
|
||||
self.task_manager.db = original_db
|
||||
else:
|
||||
# If it's a Pydantic object, convert to dictionary
|
||||
if hasattr(self.agent_card, "model_dump"):
|
||||
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
|
||||
else:
|
||||
return JSONResponse(self.agent_card)
|
888
src/services/a2a_task_manager_service.py
Normal file
888
src/services/a2a_task_manager_service.py
Normal file
@ -0,0 +1,888 @@
|
||||
"""
|
||||
A2A Task Manager Service.
|
||||
|
||||
This service implements task management for the A2A protocol, handling task lifecycle
|
||||
including execution, streaming, push notifications, status queries, and cancellation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union, AsyncIterable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.schemas.a2a.exceptions import (
|
||||
TaskNotFoundError,
|
||||
TaskNotCancelableError,
|
||||
PushNotificationNotSupportedError,
|
||||
InternalError,
|
||||
ContentTypeNotSupportedError,
|
||||
)
|
||||
|
||||
from src.schemas.a2a.types import (
|
||||
JSONRPCResponse,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
GetTaskRequest,
|
||||
SendTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskResponse,
|
||||
CancelTaskResponse,
|
||||
SendTaskResponse,
|
||||
SetTaskPushNotificationResponse,
|
||||
GetTaskPushNotificationResponse,
|
||||
TaskSendParams,
|
||||
TaskStatus,
|
||||
TaskState,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
Artifact,
|
||||
PushNotificationConfig,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
JSONRPCError,
|
||||
TaskPushNotificationConfig,
|
||||
Message,
|
||||
TextPart,
|
||||
Task,
|
||||
)
|
||||
from src.services.redis_cache_service import RedisCacheService
|
||||
from src.utils.a2a_utils import (
|
||||
are_modalities_compatible,
|
||||
new_incompatible_types_error,
|
||||
new_not_implemented_error,
|
||||
create_task_id,
|
||||
format_error_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A2ATaskManager:
|
||||
"""
|
||||
A2A Task Manager implementation.
|
||||
|
||||
This class manages the lifecycle of A2A tasks, including:
|
||||
- Task submission and execution
|
||||
- Task status queries
|
||||
- Task cancellation
|
||||
- Push notification configuration
|
||||
- SSE streaming for real-time updates
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_cache: RedisCacheService,
|
||||
agent_runner=None,
|
||||
streaming_service=None,
|
||||
push_notification_service=None,
|
||||
db=None,
|
||||
):
|
||||
"""
|
||||
Initialize the A2A Task Manager.
|
||||
|
||||
Args:
|
||||
redis_cache: Redis cache service for task storage
|
||||
agent_runner: Agent runner service for task execution
|
||||
streaming_service: Streaming service for SSE
|
||||
push_notification_service: Service for push notifications
|
||||
db: Database session
|
||||
"""
|
||||
self.redis_cache = redis_cache
|
||||
self.agent_runner = agent_runner
|
||||
self.streaming_service = streaming_service
|
||||
self.push_notification_service = push_notification_service
|
||||
self.db = db
|
||||
self.lock = asyncio.Lock()
|
||||
self.subscriber_lock = asyncio.Lock()
|
||||
self.task_sse_subscribers = {}
|
||||
|
||||
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
|
||||
"""
|
||||
Manipula requisição para obter informações sobre uma tarefa.
|
||||
|
||||
Args:
|
||||
request: Requisição Get Task do A2A
|
||||
|
||||
Returns:
|
||||
Resposta com os detalhes da tarefa
|
||||
"""
|
||||
try:
|
||||
task_id = request.params.id
|
||||
history_length = request.params.historyLength
|
||||
|
||||
# Busca dados da tarefa do cache
|
||||
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||
|
||||
if not task_data:
|
||||
logger.warning(f"Tarefa não encontrada: {task_id}")
|
||||
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Cria uma instância Task a partir dos dados do cache
|
||||
task = Task.model_validate(task_data)
|
||||
|
||||
# Se o parâmetro historyLength estiver presente, manipula o histórico
|
||||
if history_length is not None and task.history:
|
||||
if history_length == 0:
|
||||
task.history = []
|
||||
elif len(task.history) > history_length:
|
||||
task.history = task.history[-history_length:]
|
||||
|
||||
return GetTaskResponse(id=request.id, result=task)
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao processar on_get_task: {str(e)}")
|
||||
return GetTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
||||
|
||||
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
|
||||
"""
|
||||
Handle request to cancel a running task.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to cancel a task
|
||||
|
||||
Returns:
|
||||
Response with updated task data or error
|
||||
"""
|
||||
logger.info(f"Cancelling task {request.params.id}")
|
||||
task_id_params = request.params
|
||||
|
||||
try:
|
||||
task_data = await self.redis_cache.get(f"task:{task_id_params.id}")
|
||||
if not task_data:
|
||||
logger.warning(f"Task {task_id_params.id} not found for cancellation")
|
||||
return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Check if task can be cancelled
|
||||
current_state = task_data.get("status", {}).get("state")
|
||||
if current_state not in [TaskState.SUBMITTED, TaskState.WORKING]:
|
||||
logger.warning(
|
||||
f"Task {task_id_params.id} in state {current_state} cannot be cancelled"
|
||||
)
|
||||
return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())
|
||||
|
||||
# Update task status to cancelled
|
||||
task_data["status"] = {
|
||||
"state": TaskState.CANCELED,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": "Task cancelled by user"}],
|
||||
},
|
||||
}
|
||||
|
||||
# Save updated task data
|
||||
await self.redis_cache.set(f"task:{task_id_params.id}", task_data)
|
||||
|
||||
# Send push notification if configured
|
||||
await self._send_push_notification_for_task(
|
||||
task_id_params.id, "canceled", system_message="Task cancelled by user"
|
||||
)
|
||||
|
||||
# Publish event to SSE subscribers
|
||||
await self._publish_task_update(
|
||||
task_id_params.id,
|
||||
TaskStatusUpdateEvent(
|
||||
id=task_id_params.id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.CANCELED,
|
||||
timestamp=datetime.now(),
|
||||
message=Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text="Task cancelled by user")],
|
||||
),
|
||||
),
|
||||
final=True,
|
||||
),
|
||||
)
|
||||
|
||||
return CancelTaskResponse(id=request.id, result=task_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling task: {str(e)}", exc_info=True)
|
||||
return CancelTaskResponse(
|
||||
id=request.id,
|
||||
error=InternalError(message=f"Error cancelling task: {str(e)}"),
|
||||
)
|
||||
|
||||
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
|
||||
"""
|
||||
Manipula requisição para enviar uma nova tarefa.
|
||||
|
||||
Args:
|
||||
request: Requisição de envio de tarefa
|
||||
|
||||
Returns:
|
||||
Resposta com os detalhes da tarefa criada
|
||||
"""
|
||||
try:
|
||||
params = request.params
|
||||
task_id = params.id
|
||||
logger.info(f"Recebendo tarefa {task_id}")
|
||||
|
||||
# Verifica se já existe uma tarefa com esse ID
|
||||
existing_task = await self.redis_cache.get(f"task:{task_id}")
|
||||
if existing_task:
|
||||
# Se a tarefa já existe e está em progresso, retorna a tarefa atual
|
||||
if existing_task.get("status", {}).get("state") in [
|
||||
TaskState.WORKING,
|
||||
TaskState.COMPLETED,
|
||||
]:
|
||||
logger.info(
|
||||
f"Tarefa {task_id} já existe e está em progresso/concluída"
|
||||
)
|
||||
return SendTaskResponse(
|
||||
id=request.id, result=Task.model_validate(existing_task)
|
||||
)
|
||||
|
||||
# Se a tarefa existe mas falhou ou foi cancelada, podemos reprocessá-la
|
||||
logger.info(f"Reprocessando tarefa existente {task_id}")
|
||||
|
||||
# Verifica compatibilidade de modalidades
|
||||
server_output_modes = []
|
||||
if self.agent_runner:
|
||||
# Tenta obter modos suportados do agente
|
||||
try:
|
||||
server_output_modes = await self.agent_runner.get_supported_modes()
|
||||
except Exception as e:
|
||||
logger.warning(f"Erro ao obter modos suportados: {str(e)}")
|
||||
server_output_modes = ["text"] # Fallback para texto
|
||||
|
||||
if not are_modalities_compatible(
|
||||
server_output_modes, params.acceptedOutputModes
|
||||
):
|
||||
logger.warning(
|
||||
f"Modos incompatíveis: servidor={server_output_modes}, cliente={params.acceptedOutputModes}"
|
||||
)
|
||||
return SendTaskResponse(
|
||||
id=request.id, error=ContentTypeNotSupportedError()
|
||||
)
|
||||
|
||||
# Cria dados da tarefa
|
||||
task_data = await self._create_task_data(params)
|
||||
|
||||
# Armazena a tarefa no cache
|
||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Configura notificações push, se fornecidas
|
||||
if params.pushNotification:
|
||||
await self.redis_cache.set(
|
||||
f"task_notification:{task_id}", params.pushNotification.model_dump()
|
||||
)
|
||||
|
||||
# Inicia a execução da tarefa em background
|
||||
asyncio.create_task(self._execute_task(task_data, params))
|
||||
|
||||
# Converte para objeto Task e retorna
|
||||
task = Task.model_validate(task_data)
|
||||
return SendTaskResponse(id=request.id, result=task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao processar on_send_task: {str(e)}")
|
||||
return SendTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
||||
|
||||
async def on_send_task_subscribe(
|
||||
self, request: SendTaskStreamingRequest
|
||||
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||
"""
|
||||
Handle request to send a task and subscribe to streaming updates.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to send a task with streaming
|
||||
|
||||
Returns:
|
||||
Stream of events or error response
|
||||
"""
|
||||
logger.info(f"Sending task with streaming {request.params.id}")
|
||||
task_send_params = request.params
|
||||
|
||||
try:
|
||||
# Check output mode compatibility
|
||||
if not are_modalities_compatible(
|
||||
["text", "application/json"], # Default supported modes
|
||||
task_send_params.acceptedOutputModes,
|
||||
):
|
||||
return new_incompatible_types_error(request.id)
|
||||
|
||||
# Create initial task data
|
||||
task_data = await self._create_task_data(task_send_params)
|
||||
|
||||
# Setup SSE consumer
|
||||
sse_queue = await self._setup_sse_consumer(task_send_params.id)
|
||||
|
||||
# Execute task asynchronously (fire and forget)
|
||||
asyncio.create_task(self._execute_task(task_data, task_send_params))
|
||||
|
||||
# Return generator to dequeue events for SSE
|
||||
return self._dequeue_events_for_sse(
|
||||
request.id, task_send_params.id, sse_queue
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True)
|
||||
return SendTaskStreamingResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message=f"Error setting up streaming task: {str(e)}"
|
||||
),
|
||||
)
|
||||
|
||||
async def on_set_task_push_notification(
|
||||
self, request: SetTaskPushNotificationRequest
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
"""
|
||||
Configure push notifications for a task.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to set push notification
|
||||
|
||||
Returns:
|
||||
Response with configuration or error
|
||||
"""
|
||||
logger.info(f"Setting push notification for task {request.params.id}")
|
||||
task_notification_params = request.params
|
||||
|
||||
try:
|
||||
if not self.push_notification_service:
|
||||
logger.warning("Push notifications not supported")
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id, error=PushNotificationNotSupportedError()
|
||||
)
|
||||
|
||||
# Check if task exists
|
||||
task_data = await self.redis_cache.get(
|
||||
f"task:{task_notification_params.id}"
|
||||
)
|
||||
if not task_data:
|
||||
logger.warning(
|
||||
f"Task {task_notification_params.id} not found for setting push notification"
|
||||
)
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id, error=TaskNotFoundError()
|
||||
)
|
||||
|
||||
# Save push notification config
|
||||
config = {
|
||||
"url": task_notification_params.pushNotificationConfig.url,
|
||||
"headers": {}, # Add auth headers if needed
|
||||
}
|
||||
|
||||
await self.redis_cache.set(
|
||||
f"task:{task_notification_params.id}:push", config
|
||||
)
|
||||
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id, result=task_notification_params
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting push notification: {str(e)}", exc_info=True)
|
||||
return SetTaskPushNotificationResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message=f"Error setting push notification: {str(e)}"
|
||||
),
|
||||
)
|
||||
|
||||
async def on_get_task_push_notification(
|
||||
self, request: GetTaskPushNotificationRequest
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
"""
|
||||
Get push notification configuration for a task.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to get push notification config
|
||||
|
||||
Returns:
|
||||
Response with configuration or error
|
||||
"""
|
||||
logger.info(f"Getting push notification for task {request.params.id}")
|
||||
task_params = request.params
|
||||
|
||||
try:
|
||||
if not self.push_notification_service:
|
||||
logger.warning("Push notifications not supported")
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id, error=PushNotificationNotSupportedError()
|
||||
)
|
||||
|
||||
# Check if task exists
|
||||
task_data = await self.redis_cache.get(f"task:{task_params.id}")
|
||||
if not task_data:
|
||||
logger.warning(
|
||||
f"Task {task_params.id} not found for getting push notification"
|
||||
)
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id, error=TaskNotFoundError()
|
||||
)
|
||||
|
||||
# Get push notification config
|
||||
config = await self.redis_cache.get(f"task:{task_params.id}:push")
|
||||
if not config:
|
||||
logger.warning(f"No push notification config for task {task_params.id}")
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message="No push notification configuration found"
|
||||
),
|
||||
)
|
||||
|
||||
result = TaskPushNotificationConfig(
|
||||
id=task_params.id,
|
||||
pushNotificationConfig=PushNotificationConfig(
|
||||
url=config.get("url"), token=None, authentication=None
|
||||
),
|
||||
)
|
||||
|
||||
return GetTaskPushNotificationResponse(id=request.id, result=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting push notification: {str(e)}", exc_info=True)
|
||||
return GetTaskPushNotificationResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message=f"Error getting push notification: {str(e)}"
|
||||
),
|
||||
)
|
||||
|
||||
async def on_resubscribe_to_task(
|
||||
self, request: TaskResubscriptionRequest
|
||||
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
|
||||
"""
|
||||
Resubscribe to a task's streaming events.
|
||||
|
||||
Args:
|
||||
request: The JSON-RPC request to resubscribe
|
||||
|
||||
Returns:
|
||||
Stream of events or error response
|
||||
"""
|
||||
logger.info(f"Resubscribing to task {request.params.id}")
|
||||
task_params = request.params
|
||||
|
||||
try:
|
||||
# Check if task exists
|
||||
task_data = await self.redis_cache.get(f"task:{task_params.id}")
|
||||
if not task_data:
|
||||
logger.warning(f"Task {task_params.id} not found for resubscription")
|
||||
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Setup SSE consumer with resubscribe flag
|
||||
try:
|
||||
sse_queue = await self._setup_sse_consumer(
|
||||
task_params.id, is_resubscribe=True
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Task {task_params.id} not available for resubscription"
|
||||
)
|
||||
return JSONRPCResponse(
|
||||
id=request.id,
|
||||
error=InternalError(
|
||||
message="Task not available for resubscription"
|
||||
),
|
||||
)
|
||||
|
||||
# Send initial status update to the new subscriber
|
||||
status = task_data.get("status", {})
|
||||
final = status.get("state") in [
|
||||
TaskState.COMPLETED,
|
||||
TaskState.FAILED,
|
||||
TaskState.CANCELED,
|
||||
]
|
||||
|
||||
# Convert to TaskStatus object
|
||||
task_status = TaskStatus(
|
||||
state=status.get("state", TaskState.UNKNOWN),
|
||||
timestamp=datetime.fromisoformat(
|
||||
status.get("timestamp", datetime.now().isoformat())
|
||||
),
|
||||
message=status.get("message"),
|
||||
)
|
||||
|
||||
# Publish to the specific queue
|
||||
await sse_queue.put(
|
||||
TaskStatusUpdateEvent(
|
||||
id=task_params.id, status=task_status, final=final
|
||||
)
|
||||
)
|
||||
|
||||
# Return generator to dequeue events for SSE
|
||||
return self._dequeue_events_for_sse(request.id, task_params.id, sse_queue)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True)
|
||||
return JSONRPCResponse(
|
||||
id=request.id,
|
||||
error=InternalError(message=f"Error resubscribing to task: {str(e)}"),
|
||||
)
|
||||
|
||||
async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial task data structure.
|
||||
|
||||
Args:
|
||||
params: Task send parameters
|
||||
|
||||
Returns:
|
||||
Task data dictionary
|
||||
"""
|
||||
# Create task with initial status
|
||||
task_data = {
|
||||
"id": params.id,
|
||||
"sessionId": params.sessionId,
|
||||
"status": {
|
||||
"state": TaskState.SUBMITTED,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": None,
|
||||
"error": None,
|
||||
},
|
||||
"artifacts": [],
|
||||
"history": [params.message.model_dump()],
|
||||
"metadata": params.metadata or {},
|
||||
}
|
||||
|
||||
# Save task data
|
||||
await self.redis_cache.set(f"task:{params.id}", task_data)
|
||||
|
||||
return task_data
|
||||
|
||||
async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None:
|
||||
"""
|
||||
Executa uma tarefa usando o adaptador do agente.
|
||||
|
||||
Esta função é responsável pela execução real da tarefa pelo agente,
|
||||
atualizando seu status conforme o progresso.
|
||||
|
||||
Args:
|
||||
task: Dados da tarefa a ser executada
|
||||
params: Parâmetros de envio da tarefa
|
||||
"""
|
||||
task_id = task["id"]
|
||||
agent_id = params.agentId
|
||||
message_text = ""
|
||||
|
||||
# Extrai o texto da mensagem
|
||||
if params.message and params.message.parts:
|
||||
for part in params.message.parts:
|
||||
if part.type == "text":
|
||||
message_text += part.text
|
||||
|
||||
if not message_text:
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.FAILED, "Mensagem não contém texto", final=True
|
||||
)
|
||||
return
|
||||
|
||||
# Verificamos se é uma execução em andamento
|
||||
task_status = task.get("status", {})
|
||||
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
|
||||
logger.info(f"Tarefa {task_id} já está em execução ou concluída")
|
||||
return
|
||||
|
||||
try:
|
||||
# Atualiza para estado "working"
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.WORKING, "Processando solicitação"
|
||||
)
|
||||
|
||||
# Executa o agente
|
||||
if self.agent_runner:
|
||||
response = await self.agent_runner.run_agent(
|
||||
agent_id=agent_id,
|
||||
message=message_text,
|
||||
session_id=params.sessionId,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# Processa a resposta do agente
|
||||
if response and isinstance(response, dict):
|
||||
# Extrai texto da resposta
|
||||
response_text = response.get("text", "")
|
||||
if not response_text and "message" in response:
|
||||
message = response.get("message", {})
|
||||
parts = message.get("parts", [])
|
||||
for part in parts:
|
||||
if part.get("type") == "text":
|
||||
response_text += part.get("text", "")
|
||||
|
||||
# Constrói a mensagem final do agente
|
||||
if response_text:
|
||||
# Cria um artefato para a resposta
|
||||
artifact = Artifact(
|
||||
name="response",
|
||||
parts=[TextPart(text=response_text)],
|
||||
index=0,
|
||||
lastChunk=True,
|
||||
)
|
||||
|
||||
# Adiciona o artefato à tarefa
|
||||
await self._add_task_artifact(task_id, artifact)
|
||||
|
||||
# Atualiza o status da tarefa para completado
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.COMPLETED, response_text, final=True
|
||||
)
|
||||
else:
|
||||
await self._update_task_status(
|
||||
task_id,
|
||||
TaskState.FAILED,
|
||||
"O agente não retornou uma resposta válida",
|
||||
final=True,
|
||||
)
|
||||
else:
|
||||
await self._update_task_status(
|
||||
task_id,
|
||||
TaskState.FAILED,
|
||||
"Resposta inválida do agente",
|
||||
final=True,
|
||||
)
|
||||
else:
|
||||
await self._update_task_status(
|
||||
task_id,
|
||||
TaskState.FAILED,
|
||||
"Adaptador do agente não configurado",
|
||||
final=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Erro na execução da tarefa {task_id}: {str(e)}")
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.FAILED, f"Erro ao processar: {str(e)}", final=True
|
||||
)
|
||||
|
||||
async def _update_task_status(
|
||||
self, task_id: str, state: TaskState, message_text: str, final: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Atualiza o status de uma tarefa.
|
||||
|
||||
Args:
|
||||
task_id: ID da tarefa a ser atualizada
|
||||
state: Novo estado da tarefa
|
||||
message_text: Texto da mensagem associada ao status
|
||||
final: Indica se este é o status final da tarefa
|
||||
"""
|
||||
try:
|
||||
# Busca dados atuais da tarefa
|
||||
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||
if not task_data:
|
||||
logger.warning(
|
||||
f"Não foi possível atualizar status: tarefa {task_id} não encontrada"
|
||||
)
|
||||
return
|
||||
|
||||
# Cria objeto de status com a mensagem
|
||||
agent_message = Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text=message_text)],
|
||||
metadata={"timestamp": datetime.now().isoformat()},
|
||||
)
|
||||
|
||||
status = TaskStatus(
|
||||
state=state, message=agent_message, timestamp=datetime.now()
|
||||
)
|
||||
|
||||
# Atualiza o status na tarefa
|
||||
task_data["status"] = status.model_dump(exclude_none=True)
|
||||
|
||||
# Atualiza o histórico, se existir
|
||||
if "history" not in task_data:
|
||||
task_data["history"] = []
|
||||
|
||||
# Adiciona a mensagem ao histórico
|
||||
task_data["history"].append(agent_message.model_dump(exclude_none=True))
|
||||
|
||||
# Armazena a tarefa atualizada
|
||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Cria evento de atualização de status
|
||||
status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final)
|
||||
|
||||
# Publica atualização
|
||||
await self._publish_task_update(task_id, status_event)
|
||||
|
||||
# Envia notificação push, se configurada
|
||||
if final or state in [
|
||||
TaskState.FAILED,
|
||||
TaskState.COMPLETED,
|
||||
TaskState.CANCELED,
|
||||
]:
|
||||
await self._send_push_notification_for_task(
|
||||
task_id=task_id, state=state, message_text=message_text
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao atualizar status da tarefa {task_id}: {str(e)}")
|
||||
|
||||
async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None:
|
||||
"""
|
||||
Add an artifact to a task and publish the update.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
artifact: Artifact to add
|
||||
"""
|
||||
logger.info(f"Adding artifact to task {task_id}")
|
||||
|
||||
# Update task data
|
||||
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||
if task_data:
|
||||
if "artifacts" not in task_data:
|
||||
task_data["artifacts"] = []
|
||||
|
||||
# Convert artifact to dict
|
||||
artifact_dict = artifact.model_dump()
|
||||
task_data["artifacts"].append(artifact_dict)
|
||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Create artifact update event
|
||||
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact)
|
||||
|
||||
# Publish event
|
||||
await self._publish_task_update(task_id, event)
|
||||
|
||||
async def _publish_task_update(
|
||||
self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]
|
||||
) -> None:
|
||||
"""
|
||||
Publish a task update event to all subscribers.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
event: Event to publish
|
||||
"""
|
||||
async with self.subscriber_lock:
|
||||
if task_id not in self.task_sse_subscribers:
|
||||
return
|
||||
|
||||
subscribers = self.task_sse_subscribers[task_id]
|
||||
for subscriber in subscribers:
|
||||
try:
|
||||
await subscriber.put(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Error publishing event to subscriber: {str(e)}")
|
||||
|
||||
async def _send_push_notification_for_task(
|
||||
self,
|
||||
task_id: str,
|
||||
state: str,
|
||||
message_text: str = None,
|
||||
system_message: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send push notification for a task if configured.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
state: Task state
|
||||
message_text: Optional message text
|
||||
system_message: Optional system message
|
||||
"""
|
||||
if not self.push_notification_service:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get push notification config
|
||||
config = await self.redis_cache.get(f"task:{task_id}:push")
|
||||
if not config:
|
||||
return
|
||||
|
||||
# Create message if provided
|
||||
message = None
|
||||
if message_text:
|
||||
message = {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": message_text}],
|
||||
}
|
||||
elif system_message:
|
||||
# We use 'agent' instead of 'system' since Message only accepts 'user' or 'agent'
|
||||
message = {
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": system_message}],
|
||||
}
|
||||
|
||||
# Send notification
|
||||
await self.push_notification_service.send_notification(
|
||||
url=config["url"],
|
||||
task_id=task_id,
|
||||
state=state,
|
||||
message=message,
|
||||
headers=config.get("headers", {}),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error sending push notification for task {task_id}: {str(e)}"
|
||||
)
|
||||
|
||||
async def _setup_sse_consumer(
|
||||
self, task_id: str, is_resubscribe: bool = False
|
||||
) -> asyncio.Queue:
|
||||
"""
|
||||
Set up an SSE consumer queue for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
is_resubscribe: Whether this is a resubscription
|
||||
|
||||
Returns:
|
||||
Queue for events
|
||||
|
||||
Raises:
|
||||
ValueError: If resubscribing to non-existent task
|
||||
"""
|
||||
async with self.subscriber_lock:
|
||||
if task_id not in self.task_sse_subscribers:
|
||||
if is_resubscribe:
|
||||
raise ValueError("Task not found for resubscription")
|
||||
self.task_sse_subscribers[task_id] = []
|
||||
|
||||
queue = asyncio.Queue()
|
||||
self.task_sse_subscribers[task_id].append(queue)
|
||||
return queue
|
||||
|
||||
async def _dequeue_events_for_sse(
|
||||
self, request_id: str, task_id: str, event_queue: asyncio.Queue
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
"""
|
||||
Dequeue and yield events for SSE streaming.
|
||||
|
||||
Args:
|
||||
request_id: Request ID
|
||||
task_id: Task ID
|
||||
event_queue: Queue for events
|
||||
|
||||
Yields:
|
||||
SSE events wrapped in SendTaskStreamingResponse
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
event = await event_queue.get()
|
||||
|
||||
if isinstance(event, JSONRPCError):
|
||||
yield SendTaskStreamingResponse(id=request_id, error=event)
|
||||
break
|
||||
|
||||
yield SendTaskStreamingResponse(id=request_id, result=event)
|
||||
|
||||
# Check if this is the final event
|
||||
is_final = False
|
||||
if hasattr(event, "final") and event.final:
|
||||
is_final = True
|
||||
|
||||
if is_final:
|
||||
break
|
||||
finally:
|
||||
# Clean up the subscription when done
|
||||
async with self.subscriber_lock:
|
||||
if task_id in self.task_sse_subscribers:
|
||||
try:
|
||||
self.task_sse_subscribers[task_id].remove(event_queue)
|
||||
# Remove the task from the dict if no more subscribers
|
||||
if not self.task_sse_subscribers[task_id]:
|
||||
del self.task_sse_subscribers[task_id]
|
||||
except ValueError:
|
||||
pass # Queue might have been removed already
|
@ -3,7 +3,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from fastapi import HTTPException, status
|
||||
from src.models.models import Agent
|
||||
from src.schemas.schemas import AgentCreate
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from src.services.mcp_server_service import get_mcp_server
|
||||
import uuid
|
||||
import logging
|
||||
@ -20,9 +20,17 @@ def validate_sub_agents(db: Session, sub_agents: List[uuid.UUID]) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_agent(db: Session, agent_id: uuid.UUID) -> Optional[Agent]:
|
||||
def get_agent(db: Session, agent_id: Union[uuid.UUID, str]) -> Optional[Agent]:
|
||||
"""Search for an agent by ID"""
|
||||
try:
|
||||
# Convert to UUID if it's a string
|
||||
if isinstance(agent_id, str):
|
||||
try:
|
||||
agent_id = uuid.UUID(agent_id)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid agent ID: {agent_id}")
|
||||
return None
|
||||
|
||||
agent = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent not found: {agent_id}")
|
||||
|
265
src/services/push_notification_auth_service.py
Normal file
265
src/services/push_notification_auth_service.py
Normal file
@ -0,0 +1,265 @@
|
||||
"""
|
||||
Push Notification Authentication Service.
|
||||
|
||||
This service implements JWT authentication for A2A push notifications,
|
||||
allowing agents to send authenticated notifications and clients to verify
|
||||
the authenticity of received notifications.
|
||||
"""
|
||||
|
||||
from jwcrypto import jwk
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
import hashlib
|
||||
import httpx
|
||||
import logging
|
||||
import jwt
|
||||
from jwt import PyJWK, PyJWKClient
|
||||
from fastapi import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
AUTH_HEADER_PREFIX = "Bearer "
|
||||
|
||||
|
||||
class PushNotificationAuth:
|
||||
"""
|
||||
Base class for push notification authentication.
|
||||
|
||||
Contains common methods used by both the sender and the receiver
|
||||
of push notifications.
|
||||
"""
|
||||
|
||||
def _calculate_request_body_sha256(self, data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Calculates the SHA256 hash of the request body.
|
||||
|
||||
This logic needs to be the same for the agent that signs the payload
|
||||
and for the client that verifies it.
|
||||
|
||||
Args:
|
||||
data: Request body data
|
||||
|
||||
Returns:
|
||||
SHA256 hash as a hexadecimal string
|
||||
"""
|
||||
body_str = json.dumps(
|
||||
data,
|
||||
ensure_ascii=False,
|
||||
allow_nan=False,
|
||||
indent=None,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return hashlib.sha256(body_str.encode()).hexdigest()
|
||||
|
||||
|
||||
class PushNotificationSenderAuth(PushNotificationAuth):
|
||||
"""
|
||||
Authentication for the push notification sender.
|
||||
|
||||
This class is used by the A2A server to authenticate notifications
|
||||
sent to callback URLs of clients.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the push notification sender authentication service.
|
||||
"""
|
||||
self.public_keys = []
|
||||
self.private_key_jwk = None
|
||||
|
||||
@staticmethod
|
||||
async def verify_push_notification_url(url: str) -> bool:
|
||||
"""
|
||||
Verifies if a push notification URL is valid and responds correctly.
|
||||
|
||||
Sends a validation token and verifies if the response contains the same token.
|
||||
|
||||
Args:
|
||||
url: URL to be verified
|
||||
|
||||
Returns:
|
||||
True if the URL is verified successfully, False otherwise
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
validation_token = str(uuid.uuid4())
|
||||
response = await client.get(
|
||||
url, params={"validationToken": validation_token}
|
||||
)
|
||||
response.raise_for_status()
|
||||
is_verified = response.text == validation_token
|
||||
|
||||
logger.info(f"Push notification URL verified: {url} => {is_verified}")
|
||||
return is_verified
|
||||
except Exception as e:
|
||||
logger.warning(f"Error verifying push notification URL {url}: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def generate_jwk(self):
|
||||
"""
|
||||
Generates a new JWK pair for signing.
|
||||
|
||||
The key pair is used to sign push notifications.
|
||||
The public key is available via the JWKS endpoint.
|
||||
"""
|
||||
key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
|
||||
self.public_keys.append(key.export_public(as_dict=True))
|
||||
self.private_key_jwk = PyJWK.from_json(key.export_private())
|
||||
|
||||
def handle_jwks_endpoint(self, _request: Request) -> JSONResponse:
|
||||
"""
|
||||
Handles the JWKS endpoint to allow clients to obtain the public keys.
|
||||
|
||||
Args:
|
||||
_request: HTTP request
|
||||
|
||||
Returns:
|
||||
JSON response with the public keys
|
||||
"""
|
||||
return JSONResponse({"keys": self.public_keys})
|
||||
|
||||
def _generate_jwt(self, data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates a JWT token by signing the SHA256 hash of the payload and the timestamp.
|
||||
|
||||
The payload is signed with the private key to ensure integrity.
|
||||
The timestamp (iat) prevents replay attacks.
|
||||
|
||||
Args:
|
||||
data: Payload data
|
||||
|
||||
Returns:
|
||||
Signed JWT token
|
||||
"""
|
||||
iat = int(time.time())
|
||||
|
||||
return jwt.encode(
|
||||
{
|
||||
"iat": iat,
|
||||
"request_body_sha256": self._calculate_request_body_sha256(data),
|
||||
},
|
||||
key=self.private_key_jwk.key,
|
||||
headers={"kid": self.private_key_jwk.key_id},
|
||||
algorithm="RS256",
|
||||
)
|
||||
|
||||
async def send_push_notification(self, url: str, data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Sends an authenticated push notification to the specified URL.
|
||||
|
||||
Args:
|
||||
url: URL to send the notification
|
||||
data: Notification data
|
||||
|
||||
Returns:
|
||||
True if the notification was sent successfully, False otherwise
|
||||
"""
|
||||
if not self.private_key_jwk:
|
||||
logger.error(
|
||||
"Attempt to send push notification without generating JWK keys"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
jwt_token = self._generate_jwt(data)
|
||||
headers = {"Authorization": f"Bearer {jwt_token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info(f"Push notification sent to URL: {url}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending push notification to URL {url}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class PushNotificationReceiverAuth(PushNotificationAuth):
|
||||
"""
|
||||
Authentication for the push notification receiver.
|
||||
|
||||
This class is used by clients to verify the authenticity
|
||||
of notifications received from the A2A server.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the push notification receiver authentication service.
|
||||
"""
|
||||
self.jwks_client = None
|
||||
|
||||
async def load_jwks(self, jwks_url: str):
|
||||
"""
|
||||
Loads the public JWKS keys from a URL.
|
||||
|
||||
Args:
|
||||
jwks_url: URL of the JWKS endpoint
|
||||
"""
|
||||
self.jwks_client = PyJWKClient(jwks_url)
|
||||
|
||||
async def verify_push_notification(self, request: Request) -> bool:
|
||||
"""
|
||||
Verifies the authenticity of a push notification.
|
||||
|
||||
Args:
|
||||
request: HTTP request containing the notification
|
||||
|
||||
Returns:
|
||||
True if the notification is authentic, False otherwise
|
||||
|
||||
Raises:
|
||||
ValueError: If the token is invalid or expired
|
||||
"""
|
||||
if not self.jwks_client:
|
||||
logger.error("Attempt to verify notification without loading JWKS keys")
|
||||
return False
|
||||
|
||||
# Verify authentication header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
|
||||
logger.warning("Invalid authorization header")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Extract and verify token
|
||||
token = auth_header[len(AUTH_HEADER_PREFIX) :]
|
||||
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
|
||||
|
||||
# Decode token
|
||||
decode_token = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
options={"require": ["iat", "request_body_sha256"]},
|
||||
algorithms=["RS256"],
|
||||
)
|
||||
|
||||
# Verify request body integrity
|
||||
body_data = await request.json()
|
||||
actual_body_sha256 = self._calculate_request_body_sha256(body_data)
|
||||
if actual_body_sha256 != decode_token["request_body_sha256"]:
|
||||
# The payload signature does not match the hash in the signed token
|
||||
logger.warning("Request body hash does not match the token")
|
||||
raise ValueError("Invalid request body")
|
||||
|
||||
# Verify token age (maximum 5 minutes)
|
||||
if time.time() - decode_token["iat"] > 60 * 5:
|
||||
# Do not allow notifications older than 5 minutes
|
||||
# This prevents replay attacks
|
||||
logger.warning("Token expired")
|
||||
raise ValueError("Token expired")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying push notification: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Global instance of the push notification sender authentication service
|
||||
push_notification_auth = PushNotificationSenderAuth()
|
||||
|
||||
# Generate JWK keys on initialization
|
||||
push_notification_auth.generate_jwk()
|
@ -4,6 +4,8 @@ from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
|
||||
from src.services.push_notification_auth_service import push_notification_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -20,10 +22,24 @@ class PushNotificationService:
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
use_jwt_auth: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Envia uma notificação push para a URL especificada.
|
||||
Implementa retry com backoff exponencial.
|
||||
Send a push notification to the specified URL.
|
||||
Implements exponential backoff retry.
|
||||
|
||||
Args:
|
||||
url: URL to send the notification
|
||||
task_id: Task ID
|
||||
state: Task state
|
||||
message: Optional message
|
||||
headers: Optional HTTP headers
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Initial delay between retries
|
||||
use_jwt_auth: Whether to use JWT authentication
|
||||
|
||||
Returns:
|
||||
True if the notification was sent successfully, False otherwise
|
||||
"""
|
||||
payload = {
|
||||
"taskId": task_id,
|
||||
@ -32,18 +48,37 @@ class PushNotificationService:
|
||||
"message": message,
|
||||
}
|
||||
|
||||
# First URL verification
|
||||
if use_jwt_auth:
|
||||
is_url_valid = await push_notification_auth.verify_push_notification_url(
|
||||
url
|
||||
)
|
||||
if not is_url_valid:
|
||||
logger.warning(f"Invalid push notification URL: {url}")
|
||||
return False
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=headers or {}, timeout=10
|
||||
) as response:
|
||||
if response.status in (200, 201, 202, 204):
|
||||
if use_jwt_auth:
|
||||
# Use JWT authentication
|
||||
success = await push_notification_auth.send_push_notification(
|
||||
url, payload
|
||||
)
|
||||
if success:
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Push notification failed with status {response.status}. "
|
||||
f"Attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
else:
|
||||
# Legacy method without JWT authentication
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=headers or {}, timeout=10
|
||||
) as response:
|
||||
if response.status in (200, 201, 202, 204):
|
||||
logger.info(f"Push notification sent to URL: {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to send push notification with status {response.status}. "
|
||||
f"Attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error sending push notification: {str(e)}. "
|
||||
@ -56,9 +91,9 @@ class PushNotificationService:
|
||||
return False
|
||||
|
||||
async def close(self):
|
||||
"""Fecha a sessão HTTP"""
|
||||
"""Close the HTTP session"""
|
||||
await self.session.close()
|
||||
|
||||
|
||||
# Instância global do serviço
|
||||
# Global instance of the push notification service
|
||||
push_notification_service = PushNotificationService()
|
||||
|
556
src/services/redis_cache_service.py
Normal file
556
src/services/redis_cache_service.py
Normal file
@ -0,0 +1,556 @@
|
||||
"""
|
||||
Cache Redis service for the A2A protocol.
|
||||
|
||||
This service provides an interface for storing and retrieving data related to tasks,
|
||||
push notification configurations, and other A2A-related data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import asyncio
|
||||
import redis.asyncio as aioredis
|
||||
from redis.exceptions import RedisError
|
||||
from src.config.redis import get_redis_config, get_a2a_config
|
||||
import threading
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _InMemoryCacheFallback:
|
||||
"""
|
||||
Fallback in-memory cache implementation for when Redis is not available.
|
||||
|
||||
This should only be used for development or testing environments.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton implementation."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize cache storage."""
|
||||
if not getattr(self, "_initialized", False):
|
||||
with self._lock:
|
||||
if not getattr(self, "_initialized", False):
|
||||
self._data = {}
|
||||
self._ttls = {}
|
||||
self._hash_data = {}
|
||||
self._list_data = {}
|
||||
self._data_lock = threading.Lock()
|
||||
self._initialized = True
|
||||
logger.warning(
|
||||
"Initializing in-memory cache fallback (not for production)"
|
||||
)
|
||||
|
||||
async def set(self, key, value, ex=None):
|
||||
"""Set a key with optional expiration."""
|
||||
with self._data_lock:
|
||||
self._data[key] = value
|
||||
if ex is not None:
|
||||
self._ttls[key] = time.time() + ex
|
||||
elif key in self._ttls:
|
||||
del self._ttls[key]
|
||||
return True
|
||||
|
||||
async def setex(self, key, ex, value):
|
||||
"""Set a key with expiration."""
|
||||
return await self.set(key, value, ex)
|
||||
|
||||
async def get(self, key):
|
||||
"""Get a key value."""
|
||||
with self._data_lock:
|
||||
# Check if expired
|
||||
if key in self._ttls and time.time() > self._ttls[key]:
|
||||
del self._data[key]
|
||||
del self._ttls[key]
|
||||
return None
|
||||
return self._data.get(key)
|
||||
|
||||
async def delete(self, key):
|
||||
"""Delete a key."""
|
||||
with self._data_lock:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
if key in self._ttls:
|
||||
del self._ttls[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def exists(self, key):
|
||||
"""Check if key exists."""
|
||||
with self._data_lock:
|
||||
if key in self._ttls and time.time() > self._ttls[key]:
|
||||
del self._data[key]
|
||||
del self._ttls[key]
|
||||
return 0
|
||||
return 1 if key in self._data else 0
|
||||
|
||||
async def hset(self, key, field, value):
|
||||
"""Set a hash field."""
|
||||
with self._data_lock:
|
||||
if key not in self._hash_data:
|
||||
self._hash_data[key] = {}
|
||||
self._hash_data[key][field] = value
|
||||
return 1
|
||||
|
||||
async def hget(self, key, field):
|
||||
"""Get a hash field."""
|
||||
with self._data_lock:
|
||||
if key not in self._hash_data:
|
||||
return None
|
||||
return self._hash_data[key].get(field)
|
||||
|
||||
async def hdel(self, key, field):
|
||||
"""Delete a hash field."""
|
||||
with self._data_lock:
|
||||
if key in self._hash_data and field in self._hash_data[key]:
|
||||
del self._hash_data[key][field]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def hgetall(self, key):
|
||||
"""Get all hash fields."""
|
||||
with self._data_lock:
|
||||
if key not in self._hash_data:
|
||||
return {}
|
||||
return dict(self._hash_data[key])
|
||||
|
||||
async def rpush(self, key, value):
|
||||
"""Add to a list."""
|
||||
with self._data_lock:
|
||||
if key not in self._list_data:
|
||||
self._list_data[key] = []
|
||||
self._list_data[key].append(value)
|
||||
return len(self._list_data[key])
|
||||
|
||||
async def lrange(self, key, start, end):
|
||||
"""Get range from list."""
|
||||
with self._data_lock:
|
||||
if key not in self._list_data:
|
||||
return []
|
||||
|
||||
# Handle negative indices
|
||||
if end < 0:
|
||||
end = len(self._list_data[key]) + end + 1
|
||||
|
||||
return self._list_data[key][start:end]
|
||||
|
||||
async def expire(self, key, seconds):
|
||||
"""Set expiration on key."""
|
||||
with self._data_lock:
|
||||
if key in self._data:
|
||||
self._ttls[key] = time.time() + seconds
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def flushdb(self):
|
||||
"""Clear all data."""
|
||||
with self._data_lock:
|
||||
self._data.clear()
|
||||
self._ttls.clear()
|
||||
self._hash_data.clear()
|
||||
self._list_data.clear()
|
||||
return True
|
||||
|
||||
async def keys(self, pattern="*"):
|
||||
"""Get keys matching pattern."""
|
||||
with self._data_lock:
|
||||
# Clean expired keys
|
||||
now = time.time()
|
||||
expired_keys = [k for k, exp in self._ttls.items() if now > exp]
|
||||
for k in expired_keys:
|
||||
if k in self._data:
|
||||
del self._data[k]
|
||||
del self._ttls[k]
|
||||
|
||||
# Simple pattern matching
|
||||
result = []
|
||||
if pattern == "*":
|
||||
result = list(self._data.keys())
|
||||
elif pattern.endswith("*"):
|
||||
prefix = pattern[:-1]
|
||||
result = [k for k in self._data.keys() if k.startswith(prefix)]
|
||||
elif pattern.startswith("*"):
|
||||
suffix = pattern[1:]
|
||||
result = [k for k in self._data.keys() if k.endswith(suffix)]
|
||||
else:
|
||||
if pattern in self._data:
|
||||
result = [pattern]
|
||||
|
||||
return result
|
||||
|
||||
async def ping(self):
|
||||
"""Test connection."""
|
||||
return True
|
||||
|
||||
|
||||
class RedisCacheService:
|
||||
"""
|
||||
Cache service using Redis for storing A2A-related data.
|
||||
|
||||
This implementation uses a real Redis connection for distributed caching
|
||||
and data persistence across multiple instances.
|
||||
|
||||
If Redis is not available, falls back to an in-memory implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_url: Optional[str] = None):
|
||||
"""
|
||||
Initialize the Redis cache service.
|
||||
|
||||
Args:
|
||||
redis_url: Redis server URL (optional, defaults to config value)
|
||||
"""
|
||||
if redis_url:
|
||||
self._redis_url = redis_url
|
||||
else:
|
||||
# Construir URL a partir dos componentes de configuração
|
||||
config = get_redis_config()
|
||||
protocol = "rediss" if config.get("ssl", False) else "redis"
|
||||
auth = f":{config['password']}@" if config.get("password") else ""
|
||||
self._redis_url = (
|
||||
f"{protocol}://{auth}{config['host']}:{config['port']}/{config['db']}"
|
||||
)
|
||||
|
||||
self._redis = None
|
||||
self._in_memory_mode = False
|
||||
self._connecting = False
|
||||
self._connection_lock = asyncio.Lock()
|
||||
logger.info(f"Initializing RedisCacheService with URL: {self._redis_url}")
|
||||
|
||||
async def _get_redis(self):
|
||||
"""
|
||||
Get a Redis connection, creating it if necessary.
|
||||
Falls back to in-memory implementation if Redis is not available.
|
||||
|
||||
Returns:
|
||||
Redis connection or in-memory fallback
|
||||
"""
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
async with self._connection_lock:
|
||||
if self._redis is None and not self._connecting:
|
||||
try:
|
||||
self._connecting = True
|
||||
logger.info(f"Connecting to Redis at {self._redis_url}")
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url, encoding="utf-8", decode_responses=True
|
||||
)
|
||||
# Teste de conexão
|
||||
await self._redis.ping()
|
||||
logger.info("Redis connection successful")
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to Redis: {str(e)}")
|
||||
logger.warning(
|
||||
"Falling back to in-memory cache (not suitable for production)"
|
||||
)
|
||||
self._redis = _InMemoryCacheFallback()
|
||||
self._in_memory_mode = True
|
||||
finally:
|
||||
self._connecting = False
|
||||
|
||||
return self._redis
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
|
||||
"""
|
||||
Store a value in the cache.
|
||||
|
||||
Args:
|
||||
key: Key for the value
|
||||
value: Value to store
|
||||
ttl: Time to live in seconds (optional)
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
|
||||
# Convert dict/list to JSON string
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value)
|
||||
|
||||
if ttl:
|
||||
await redis.setex(key, ttl, value)
|
||||
else:
|
||||
await redis.set(key, value)
|
||||
|
||||
logger.debug(f"Set cache key: {key}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting Redis key {key}: {str(e)}")
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieve a value from the cache.
|
||||
|
||||
Args:
|
||||
key: Key for the value to retrieve
|
||||
|
||||
Returns:
|
||||
The stored value or None if not found
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
value = await redis.get(key)
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
# Return as-is if not JSON
|
||||
return value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis key {key}: {str(e)}")
|
||||
return None
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Remove a value from the cache.
|
||||
|
||||
Args:
|
||||
key: Key for the value to remove
|
||||
|
||||
Returns:
|
||||
True if the value was removed, False if it didn't exist
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
result = await redis.delete(key)
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting Redis key {key}: {str(e)}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in the cache.
|
||||
|
||||
Args:
|
||||
key: Key to check
|
||||
|
||||
Returns:
|
||||
True if the key exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.exists(key) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking Redis key {key}: {str(e)}")
|
||||
return False
|
||||
|
||||
async def set_hash(self, key: str, field: str, value: Any) -> None:
|
||||
"""
|
||||
Store a value in a hash.
|
||||
|
||||
Args:
|
||||
key: Hash key
|
||||
field: Hash field
|
||||
value: Value to store
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
|
||||
# Convert dict/list to JSON string
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value)
|
||||
|
||||
await redis.hset(key, field, value)
|
||||
logger.debug(f"Set hash field: {key}:{field}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting Redis hash {key}:{field}: {str(e)}")
|
||||
|
||||
async def get_hash(self, key: str, field: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieve a value from a hash.
|
||||
|
||||
Args:
|
||||
key: Hash key
|
||||
field: Hash field
|
||||
|
||||
Returns:
|
||||
The stored value or None if not found
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
value = await redis.hget(key, field)
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
# Return as-is if not JSON
|
||||
return value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis hash {key}:{field}: {str(e)}")
|
||||
return None
|
||||
|
||||
async def delete_hash(self, key: str, field: str) -> bool:
|
||||
"""
|
||||
Remove a value from a hash.
|
||||
|
||||
Args:
|
||||
key: Hash key
|
||||
field: Hash field
|
||||
|
||||
Returns:
|
||||
True if the value was removed, False if it didn't exist
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
result = await redis.hdel(key, field)
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting Redis hash {key}:{field}: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_all_hash(self, key: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve all values from a hash.
|
||||
|
||||
Args:
|
||||
key: Hash key
|
||||
|
||||
Returns:
|
||||
Dictionary with all hash values
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
result_dict = await redis.hgetall(key)
|
||||
|
||||
if not result_dict:
|
||||
return {}
|
||||
|
||||
# Try to parse each value as JSON
|
||||
parsed_dict = {}
|
||||
for field, value in result_dict.items():
|
||||
try:
|
||||
parsed_dict[field] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_dict[field] = value
|
||||
|
||||
return parsed_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all Redis hash fields for {key}: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def push_list(self, key: str, value: Any) -> int:
|
||||
"""
|
||||
Add a value to the end of a list.
|
||||
|
||||
Args:
|
||||
key: List key
|
||||
value: Value to add
|
||||
|
||||
Returns:
|
||||
Size of the list after the addition
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
|
||||
# Convert dict/list to JSON string
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value)
|
||||
|
||||
return await redis.rpush(key, value)
|
||||
except Exception as e:
|
||||
logger.error(f"Error pushing to Redis list {key}: {str(e)}")
|
||||
return 0
|
||||
|
||||
async def get_list(self, key: str, start: int = 0, end: int = -1) -> List[Any]:
|
||||
"""
|
||||
Retrieve values from a list.
|
||||
|
||||
Args:
|
||||
key: List key
|
||||
start: Initial index (inclusive)
|
||||
end: Final index (inclusive), -1 for all
|
||||
|
||||
Returns:
|
||||
List with the retrieved values
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
values = await redis.lrange(key, start, end)
|
||||
|
||||
if not values:
|
||||
return []
|
||||
|
||||
# Try to parse each value as JSON
|
||||
result = []
|
||||
for value in values:
|
||||
try:
|
||||
result.append(json.loads(value))
|
||||
except json.JSONDecodeError:
|
||||
result.append(value)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis list {key}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set a time-to-live for a key.
|
||||
|
||||
Args:
|
||||
key: Key
|
||||
ttl: Time-to-live in seconds
|
||||
|
||||
Returns:
|
||||
True if the key exists and the TTL was set, False otherwise
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.expire(key, ttl)
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting expire for Redis key {key}: {str(e)}")
|
||||
return False
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""
|
||||
Clear the entire cache.
|
||||
|
||||
Warning: This is a destructive operation and will remove all data.
|
||||
Only use in development/test environments.
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
await redis.flushdb()
|
||||
logger.warning("Redis database flushed - all data cleared")
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing Redis database: {str(e)}")
|
||||
|
||||
async def keys(self, pattern: str = "*") -> List[str]:
|
||||
"""
|
||||
Retrieve keys that match a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob pattern to filter keys
|
||||
|
||||
Returns:
|
||||
List of keys that match the pattern
|
||||
"""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.keys(pattern)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis keys with pattern {pattern}: {str(e)}")
|
||||
return []
|
@ -3,12 +3,12 @@ import json
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from datetime import datetime
|
||||
from ..schemas.streaming import (
|
||||
from src.schemas.streaming import (
|
||||
JSONRPCRequest,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from ..services.agent_runner import run_agent
|
||||
from ..services.service_providers import (
|
||||
from src.services.agent_runner import run_agent
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
@ -25,31 +25,33 @@ class StreamingService:
|
||||
agent_id: str,
|
||||
api_key: str,
|
||||
message: str,
|
||||
contact_id: str = None,
|
||||
session_id: str = None,
|
||||
db: Session = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Inicia o streaming de eventos SSE para uma tarefa.
|
||||
Starts the SSE event streaming for a task.
|
||||
|
||||
Args:
|
||||
agent_id: ID do agente
|
||||
api_key: Chave de API para autenticação
|
||||
message: Mensagem inicial
|
||||
session_id: ID da sessão (opcional)
|
||||
db: Sessão do banco de dados
|
||||
agent_id: Agent ID
|
||||
api_key: API key for authentication
|
||||
message: Initial message
|
||||
contact_id: Contact ID (optional)
|
||||
session_id: Session ID (optional)
|
||||
db: Database session
|
||||
|
||||
Yields:
|
||||
Eventos SSE formatados
|
||||
Formatted SSE events
|
||||
"""
|
||||
# Validação básica da API key
|
||||
# Basic API key validation
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="API key é obrigatória")
|
||||
raise HTTPException(status_code=401, detail="API key is required")
|
||||
|
||||
# Gera IDs únicos
|
||||
task_id = str(uuid.uuid4())
|
||||
# Generate unique IDs
|
||||
task_id = contact_id or str(uuid.uuid4())
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Monta payload JSON-RPC
|
||||
# Build JSON-RPC payload
|
||||
payload = JSONRPCRequest(
|
||||
id=request_id,
|
||||
params={
|
||||
@ -62,7 +64,7 @@ class StreamingService:
|
||||
},
|
||||
)
|
||||
|
||||
# Registra conexão
|
||||
# Register connection
|
||||
self.active_connections[task_id] = {
|
||||
"agent_id": agent_id,
|
||||
"api_key": api_key,
|
||||
@ -70,7 +72,7 @@ class StreamingService:
|
||||
}
|
||||
|
||||
try:
|
||||
# Envia evento de início
|
||||
# Send start event
|
||||
yield self._format_sse_event(
|
||||
"status",
|
||||
TaskStatusUpdateEvent(
|
||||
@ -80,10 +82,10 @@ class StreamingService:
|
||||
).model_dump_json(),
|
||||
)
|
||||
|
||||
# Executa o agente
|
||||
# Execute the agent
|
||||
result = await run_agent(
|
||||
str(agent_id),
|
||||
task_id,
|
||||
contact_id or task_id,
|
||||
message,
|
||||
session_service,
|
||||
artifacts_service,
|
||||
@ -92,7 +94,7 @@ class StreamingService:
|
||||
session_id,
|
||||
)
|
||||
|
||||
# Envia a resposta do agente como um evento separado
|
||||
# Send the agent's response as a separate event
|
||||
yield self._format_sse_event(
|
||||
"message",
|
||||
json.dumps(
|
||||
@ -104,7 +106,7 @@ class StreamingService:
|
||||
),
|
||||
)
|
||||
|
||||
# Evento de conclusão
|
||||
# Completion event
|
||||
yield self._format_sse_event(
|
||||
"status",
|
||||
TaskStatusUpdateEvent(
|
||||
@ -114,7 +116,7 @@ class StreamingService:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Evento de erro
|
||||
# Error event
|
||||
yield self._format_sse_event(
|
||||
"status",
|
||||
TaskStatusUpdateEvent(
|
||||
@ -126,14 +128,14 @@ class StreamingService:
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Limpa conexão
|
||||
# Clean connection
|
||||
self.active_connections.pop(task_id, None)
|
||||
|
||||
def _format_sse_event(self, event_type: str, data: str) -> str:
|
||||
"""Formata um evento SSE."""
|
||||
"""Format an SSE event."""
|
||||
return f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
async def close_connection(self, task_id: str):
|
||||
"""Fecha uma conexão de streaming."""
|
||||
"""Close a streaming connection."""
|
||||
if task_id in self.active_connections:
|
||||
self.active_connections.pop(task_id)
|
||||
|
110
src/utils/a2a_utils.py
Normal file
110
src/utils/a2a_utils.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""
|
||||
A2A protocol utilities.
|
||||
|
||||
This module contains utility functions for the A2A protocol implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Any, Dict
|
||||
from src.schemas.a2a import (
|
||||
ContentTypeNotSupportedError,
|
||||
UnsupportedOperationError,
|
||||
JSONRPCResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def are_modalities_compatible(
|
||||
server_output_modes: Optional[List[str]], client_output_modes: Optional[List[str]]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if server and client modalities are compatible.
|
||||
|
||||
Modalities are compatible if they are both non-empty
|
||||
and there is at least one common element.
|
||||
|
||||
Args:
|
||||
server_output_modes: List of output modes supported by the server
|
||||
client_output_modes: List of output modes requested by the client
|
||||
|
||||
Returns:
|
||||
True if compatible, False otherwise
|
||||
"""
|
||||
# If client doesn't specify modes, assume all are accepted
|
||||
if client_output_modes is None or len(client_output_modes) == 0:
|
||||
return True
|
||||
|
||||
# If server doesn't specify modes, assume all are supported
|
||||
if server_output_modes is None or len(server_output_modes) == 0:
|
||||
return True
|
||||
|
||||
# Check if there's at least one common mode
|
||||
return any(mode in server_output_modes for mode in client_output_modes)
|
||||
|
||||
|
||||
def new_incompatible_types_error(request_id: str) -> JSONRPCResponse:
|
||||
"""
|
||||
Create a JSON-RPC response for incompatible content types error.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request that caused the error
|
||||
|
||||
Returns:
|
||||
JSON-RPC response with ContentTypeNotSupportedError
|
||||
"""
|
||||
return JSONRPCResponse(id=request_id, error=ContentTypeNotSupportedError())
|
||||
|
||||
|
||||
def new_not_implemented_error(request_id: str) -> JSONRPCResponse:
|
||||
"""
|
||||
Create a JSON-RPC response for unimplemented operation error.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request that caused the error
|
||||
|
||||
Returns:
|
||||
JSON-RPC response with UnsupportedOperationError
|
||||
"""
|
||||
return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
|
||||
|
||||
|
||||
def create_task_id(agent_id: str, session_id: str, timestamp: str = None) -> str:
|
||||
"""
|
||||
Create a unique task ID for an agent and session.
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent
|
||||
session_id: The ID of the session
|
||||
timestamp: Optional timestamp to include in the ID
|
||||
|
||||
Returns:
|
||||
A unique task ID
|
||||
"""
|
||||
import uuid
|
||||
import time
|
||||
|
||||
timestamp = timestamp or str(int(time.time()))
|
||||
unique = uuid.uuid4().hex[:8]
|
||||
|
||||
return f"{agent_id}_{session_id}_{timestamp}_{unique}"
|
||||
|
||||
|
||||
def format_error_response(error: Exception, request_id: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Format an exception as a JSON-RPC error response.
|
||||
|
||||
Args:
|
||||
error: The exception to format
|
||||
request_id: The ID of the request that caused the error
|
||||
|
||||
Returns:
|
||||
JSON-RPC error response as dictionary
|
||||
"""
|
||||
from src.schemas.a2a import InternalError, JSONRPCResponse
|
||||
|
||||
error_response = JSONRPCResponse(
|
||||
id=request_id, error=InternalError(message=str(error))
|
||||
)
|
||||
|
||||
return error_response.model_dump(exclude_none=True)
|
Loading…
Reference in New Issue
Block a user