From 34734b6da7182887eafc4ccb89b48cf0eac6b6ad Mon Sep 17 00:00:00 2001 From: Davidson Gomes Date: Tue, 29 Apr 2025 20:15:02 -0300 Subject: [PATCH] feat(api): implement WebSocket support for real-time task updates and add streaming service --- .env | 19 ++-- planejamento_atualizado.md | 191 +++++++++++++++++++++++++++++++++++ pyproject.toml | 2 + src/api/chat_routes.py | 111 +++++++++++++++++++- src/core/jwt_middleware.py | 14 +++ src/main.py | 9 +- src/services/agent_runner.py | 96 +++++++++++++++++- static/test.html | 187 ++++++++++++++++++++++++++++++++++ 8 files changed, 613 insertions(+), 16 deletions(-) create mode 100644 planejamento_atualizado.md create mode 100644 static/test.html diff --git a/.env b/.env index ebe59e39..e569655a 100644 --- a/.env +++ b/.env @@ -6,28 +6,23 @@ API_URL="http://localhost:8000" ORGANIZATION_NAME="Evo AI" ORGANIZATION_URL="https://evoai.evoapicloud.com" -# Configurações do banco de dados +# Database settings POSTGRES_CONNECTION_STRING="postgresql://postgres:root@localhost:5432/evo_ai" -# Configurações de logging +# Logging settings LOG_LEVEL="INFO" LOG_DIR="logs" -# Configurações da API de Conhecimento -KNOWLEDGE_API_URL="http://localhost:5540" -KNOWLEDGE_API_KEY="sua-chave-api-conhecimento" -TENANT_ID="seu-tenant-id" - -# Configurações do Redis +# Redis settings REDIS_HOST="localhost" REDIS_PORT=6379 REDIS_DB=8 REDIS_PASSWORD="" -# TTL do cache de ferramentas em segundos (1 hora) +# Tools cache TTL in seconds (1 hour) TOOLS_CACHE_TTL=3600 -# Configurações JWT +# JWT settings JWT_SECRET_KEY="f6884ef5be4c279686ff90f0ed9d4656685eef9807245019ac94a3fbe32b0938" JWT_ALGORITHM="HS256" JWT_EXPIRATION_TIME=3600 @@ -37,12 +32,12 @@ SENDGRID_API_KEY="SG.lfmOfb13QseRA0AHTLlKlw.H9RX5wKx37URMPohaAU1D4tJimG4g0FPR2iU EMAIL_FROM="noreply@evolution-api.com" APP_URL="https://evoai.evoapicloud.com" -# Configurações do Servidor +# Server settings HOST="0.0.0.0" PORT=8000 DEBUG=false -# Configurações de Seeders +# Seeders settings ADMIN_EMAIL="admin@evoai.com" ADMIN_INITIAL_PASSWORD="senhaforte123" DEMO_EMAIL="demo@exemplo.com" diff --git a/planejamento_atualizado.md b/planejamento_atualizado.md new file mode 100644 index 00000000..00a5a78d --- /dev/null +++ b/planejamento_atualizado.md @@ -0,0 +1,191 @@ +# Planejamento de Implementação - A2A Streaming (Atualizado) + +## 1. Visão Geral + +Implementar suporte a Server-Sent Events (SSE) para streaming de atualizações de tarefas em tempo real, seguindo a especificação oficial do A2A. + +## 2. Componentes Necessários + +### 2.2 Estrutura de Arquivos + +``` +src/ +├── api/ +│ └── agent_routes.py (modificação) +├── schemas/ +│ └── streaming.py (novo) +├── services/ +│ └── streaming_service.py (novo) +└── utils/ + └── streaming.py (novo) +``` + +## 3. Implementação + +### 3.1 Schemas (Pydantic) + +```python +# schemas/streaming.py +- TaskStatusUpdateEvent + - state: str (working, completed, failed) + - timestamp: datetime + - message: Optional[Message] + - error: Optional[Error] + +- TaskArtifactUpdateEvent + - type: str + - content: str + - metadata: Dict[str, Any] + +- JSONRPCRequest + - jsonrpc: str = "2.0" + - id: str + - method: str = "tasks/sendSubscribe" + - params: Dict[str, Any] + +- Message + - role: str + - parts: List[MessagePart] + +- MessagePart + - type: str + - text: str +``` + +### 3.2 Serviço de Streaming + +````python +# services/streaming_service.py +- send_task_streaming() + - Monta payload JSON-RPC conforme especificação: + ```json + { + "jsonrpc": "2.0", + "id": "", + "method": "tasks/sendSubscribe", + "params": { + "id": "", + "sessionId": "", + "message": { + "role": "user", + "parts": [{"type": "text", "text": ""}] + } + } + } + ``` + - Configura headers: + - Accept: text/event-stream + - Authorization: x-api-key + - Gerencia conexão SSE + - Processa eventos em tempo real +```` + +### 3.3 Rota de Streaming + +```python +# api/agent_routes.py +- Nova rota POST /{agent_id}/tasks/sendSubscribe + - Validação de API key + - Gerenciamento de sessão + - Streaming de eventos SSE + - Tratamento de erros JSON-RPC +``` + +### 3.4 Utilitários + +```python +# utils/streaming.py +- Helpers para SSE + - Formatação de eventos + - Tratamento de reconexão + - Timeout e retry +- Processamento de eventos + - Parsing de eventos SSE + - Validação de payloads +- Formatação de respostas + - Conformidade com JSON-RPC 2.0 +``` + +## 4. Fluxo de Dados + +1. Cliente envia requisição JSON-RPC para `/tasks/sendSubscribe` +2. Servidor valida API key e configura sessão +3. Inicia streaming de eventos SSE +4. Envia atualizações em tempo real: + - TaskStatusUpdateEvent (estado da tarefa) + - TaskArtifactUpdateEvent (artefatos gerados) + - Mensagens do histórico + +## 5. Exemplo de Uso + +```python +async def exemplo_uso(): + agent_id = "uuid-do-agente" + api_key = "sua-api-key" + mensagem = "Olá, como posso ajudar?" + + async with httpx.AsyncClient() as client: + # Configura headers + headers = { + "Accept": "text/event-stream", + "Authorization": f"x-api-key {api_key}" + } + + # Monta payload JSON-RPC + payload = { + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": "tasks/sendSubscribe", + "params": { + "id": str(uuid.uuid4()), + "message": { + "role": "user", + "parts": [{"type": "text", "text": mensagem}] + } + } + } + + # Inicia streaming + async with connect_sse(client, "POST", f"/agents/{agent_id}/tasks/sendSubscribe", + json=payload, headers=headers) as event_source: + async for event in event_source.aiter_sse(): + if event.event == "message": + data = json.loads(event.data) + print(f"Evento recebido: {data}") +``` + +## 6. Considerações de Segurança + +- Validação rigorosa de API keys +- Timeout de conexão SSE (30 segundos) +- Tratamento de erros e reconexão automática +- Limites de taxa (rate limiting) +- Validação de payloads JSON-RPC +- Sanitização de inputs + +## 7. Testes + +- Testes unitários para schemas +- Testes de integração para streaming +- Testes de carga e performance +- Testes de reconexão e resiliência +- Testes de conformidade JSON-RPC + +## 8. Documentação + +- Atualizar documentação da API +- Adicionar exemplos de uso +- Documentar formatos de eventos +- Guia de troubleshooting +- Referência à especificação A2A + +## 9. Próximos Passos + +1. Implementar schemas Pydantic conforme especificação +2. Desenvolver serviço de streaming com suporte a JSON-RPC +3. Adicionar rota SSE com validação de payloads +4. Implementar utilitários de streaming +5. Escrever testes de conformidade +6. Atualizar documentação +7. Revisão de código +8. Deploy em ambiente de teste diff --git a/pyproject.toml b/pyproject.toml index 2df64356..5a1b2682 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ dependencies = [ "bcrypt==4.3.0", "jinja2==3.1.6", "pydantic[email]==2.11.3", + "httpx==0.28.1", + "httpx-sse==0.4.0", ] [project.optional-dependencies] diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index 2b88e871..be8746d3 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -1,15 +1,23 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import ( + APIRouter, + Depends, + HTTPException, + status, + WebSocket, + WebSocketDisconnect, +) from sqlalchemy.orm import Session from src.config.database import get_db from src.core.jwt_middleware import ( get_jwt_token, verify_user_client, + get_jwt_token_ws, ) from src.services import ( agent_service, ) from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse -from src.services.agent_runner import run_agent +from src.services.agent_runner import run_agent, run_agent_stream from src.core.exceptions import AgentNotFoundError from src.services.service_providers import ( session_service, @@ -19,6 +27,8 @@ from src.services.service_providers import ( from datetime import datetime import logging +import json +from fastapi.responses import StreamingResponse logger = logging.getLogger(__name__) @@ -29,6 +39,103 @@ router = APIRouter( ) +@router.websocket("/ws/{agent_id}/{contact_id}") +async def websocket_chat( + websocket: WebSocket, + agent_id: str, + contact_id: str, + db: Session = Depends(get_db), +): + try: + # Accept the connection + await websocket.accept() + logger.info("WebSocket connection accepted, waiting for authentication") + + # Aguardar mensagem de autenticação + try: + auth_data = await websocket.receive_json() + logger.info(f"Received authentication data: {auth_data}") + + if not auth_data.get("type") == "authorization" or not auth_data.get( + "token" + ): + logger.warning("Invalid authentication message") + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return + + token = auth_data["token"] + # Verify the token + payload = await get_jwt_token_ws(token) + if not payload: + logger.warning("Invalid token") + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return + + # Verificar se o agente pertence ao cliente do usuário + agent = agent_service.get_agent(db, agent_id) + if not agent: + logger.warning(f"Agent {agent_id} not found") + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return + + # Verificar se o usuário tem acesso ao agente (via client) + await verify_user_client(payload, db, agent.client_id) + + logger.info( + f"WebSocket connection established for agent {agent_id} and contact {contact_id}" + ) + + while True: + try: + data = await websocket.receive_json() + logger.info(f"Received message: {data}") + message = data.get("message") + + if not message: + continue + + async for chunk in run_agent_stream( + agent_id=agent_id, + contact_id=contact_id, + message=message, + session_service=session_service, + artifacts_service=artifacts_service, + memory_service=memory_service, + db=db, + ): + # Enviar cada chunk como uma mensagem JSON + await websocket.send_json( + {"message": chunk, "turn_complete": False} + ) + + # Enviar sinal de turno completo + await websocket.send_json({"message": "", "turn_complete": True}) + + except WebSocketDisconnect: + logger.info("Client disconnected") + break + except json.JSONDecodeError: + logger.warning("Invalid JSON message received") + continue + except Exception as e: + logger.error(f"Error in WebSocket message handling: {str(e)}") + await websocket.close(code=status.WS_1011_INTERNAL_ERROR) + break + + except WebSocketDisconnect: + logger.info("Client disconnected during authentication") + except json.JSONDecodeError: + logger.warning("Invalid authentication message format") + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + except Exception as e: + logger.error(f"Error during authentication: {str(e)}") + await websocket.close(code=status.WS_1011_INTERNAL_ERROR) + + except Exception as e: + logger.error(f"WebSocket error: {str(e)}") + await websocket.close(code=status.WS_1011_INTERNAL_ERROR) + + @router.post( "/", response_model=ChatResponse, diff --git a/src/core/jwt_middleware.py b/src/core/jwt_middleware.py index 918e36ba..4a313548 100644 --- a/src/core/jwt_middleware.py +++ b/src/core/jwt_middleware.py @@ -149,3 +149,17 @@ def get_current_user_client_id( return UUID(client_id) return None + + +async def get_jwt_token_ws(token: str) -> Optional[dict]: + """ + Verifica e decodifica o token JWT para WebSocket. + Retorna o payload se o token for válido, None caso contrário. + """ + try: + payload = jwt.decode( + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] + ) + return payload + except JWTError: + return None diff --git a/src/main.py b/src/main.py index 0645f172..6d862b63 100644 --- a/src/main.py +++ b/src/main.py @@ -3,6 +3,7 @@ import sys from pathlib import Path from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles from src.config.database import engine, Base from src.config.settings import settings from src.utils.logger import setup_logger @@ -39,12 +40,18 @@ app = FastAPI( # CORS configuration app.add_middleware( CORSMiddleware, - allow_origins=settings.CORS_ORIGINS, + allow_origins=["*"], # Permite todas as origens em desenvolvimento allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) +# Configuração de arquivos estáticos +static_dir = Path("static") +if not static_dir.exists(): + static_dir.mkdir(parents=True) +app.mount("/static", StaticFiles(directory=static_dir), name="static") + # PostgreSQL configuration POSTGRES_CONNECTION_STRING = os.getenv( "POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai" diff --git a/src/services/agent_runner.py b/src/services/agent_runner.py index 1abe17f0..0f3f7a38 100644 --- a/src/services/agent_runner.py +++ b/src/services/agent_runner.py @@ -8,7 +8,8 @@ from src.core.exceptions import AgentNotFoundError, InternalServerError from src.services.agent_service import get_agent from src.services.agent_builder import AgentBuilder from sqlalchemy.orm import Session -from typing import Optional +from typing import Optional, AsyncGenerator +import asyncio logger = setup_logger(__name__) @@ -101,3 +102,96 @@ async def run_agent( except Exception as e: logger.error(f"Internal error processing request: {str(e)}", exc_info=True) raise InternalServerError(str(e)) + + +async def run_agent_stream( + agent_id: str, + contact_id: str, + message: str, + session_service: DatabaseSessionService, + artifacts_service: InMemoryArtifactService, + memory_service: InMemoryMemoryService, + db: Session, + session_id: Optional[str] = None, +) -> AsyncGenerator[str, None]: + try: + logger.info( + f"Starting streaming execution of agent {agent_id} for contact {contact_id}" + ) + logger.info(f"Received message: {message}") + + get_root_agent = get_agent(db, agent_id) + logger.info( + f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})" + ) + + if get_root_agent is None: + raise AgentNotFoundError(f"Agent with ID {agent_id} not found") + + # Using the AgentBuilder to create the agent + agent_builder = AgentBuilder(db) + root_agent, exit_stack = await agent_builder.build_agent(get_root_agent) + + logger.info("Configuring Runner") + agent_runner = Runner( + agent=root_agent, + app_name=agent_id, + session_service=session_service, + artifact_service=artifacts_service, + memory_service=memory_service, + ) + adk_session_id = contact_id + "_" + agent_id + if session_id is None: + session_id = adk_session_id + + logger.info(f"Searching session for contact {contact_id}") + session = session_service.get_session( + app_name=agent_id, + user_id=contact_id, + session_id=adk_session_id, + ) + + if session is None: + logger.info(f"Creating new session for contact {contact_id}") + session = session_service.create_session( + app_name=agent_id, + user_id=contact_id, + session_id=adk_session_id, + ) + + content = Content(role="user", parts=[Part(text=message)]) + logger.info("Starting agent streaming execution") + + try: + for event in agent_runner.run( + user_id=contact_id, + session_id=adk_session_id, + new_message=content, + ): + if event.content and event.content.parts: + text = event.content.parts[0].text + if text: + yield text + await asyncio.sleep(0) # Allow other tasks to run + + completed_session = session_service.get_session( + app_name=agent_id, + user_id=contact_id, + session_id=adk_session_id, + ) + + memory_service.add_session_to_memory(completed_session) + + finally: + # Ensure the exit_stack is closed correctly + if exit_stack: + await exit_stack.aclose() + + logger.info("Agent streaming execution completed successfully") + + except AgentNotFoundError as e: + logger.error(f"Error processing request: {str(e)}") + raise e + except Exception as e: + logger.error(f"Internal error processing request: {str(e)}", exc_info=True) + raise InternalServerError(str(e)) diff --git a/static/test.html b/static/test.html new file mode 100644 index 00000000..b7a44d90 --- /dev/null +++ b/static/test.html @@ -0,0 +1,187 @@ + + + + + ADK Streaming Test + + + + +

ADK Streaming Test

+
+
+
+
+
+
+

+ + +
+ +
+ + + + + \ No newline at end of file