# evo-ai/src/api/chat_routes.py
#
# 185 lines
# 5.9 KiB
# Python

from fastapi import (
APIRouter,
Depends,
HTTPException,
status,
WebSocket,
WebSocketDisconnect,
)
from sqlalchemy.orm import Session
from src.config.database import get_db
from src.core.jwt_middleware import (
get_jwt_token,
verify_user_client,
get_jwt_token_ws,
)
from src.services import (
agent_service,
)
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
from src.services.agent_runner import run_agent, run_agent_stream
from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from datetime import datetime
import logging
import json
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router shared by the chat endpoints defined below: routes are registered
# under the "/chat" prefix, tagged "chat", and declare a 404 "Not found"
# response entry.
router = APIRouter(
prefix="/chat",
tags=["chat"],
responses={404: {"description": "Not found"}},
)
async def _close_websocket(websocket: WebSocket, code: int) -> None:
    """Close ``websocket`` with ``code``, tolerating a socket that is already gone.

    ``close`` is called from exception handlers, where the connection may
    already be closed or the peer may have vanished; a failure to close must
    not raise out of those handlers.
    """
    try:
        await websocket.close(code=code)
    except (RuntimeError, WebSocketDisconnect) as close_error:
        logger.debug(f"WebSocket already closed: {close_error}")


@router.websocket("/ws/{agent_id}/{contact_id}")
async def websocket_chat(
    websocket: WebSocket,
    agent_id: str,
    contact_id: str,
    db: Session = Depends(get_db),
):
    """Stream agent responses to a client over a WebSocket.

    Protocol, as implemented below:

    1. The connection is accepted, then the first JSON message must be
       ``{"type": "authorization", "token": <token>}``.
    2. The token is checked with ``get_jwt_token_ws``, the agent is looked up
       with ``agent_service.get_agent`` and access is checked with
       ``verify_user_client``. Any failure closes the socket with
       ``WS_1008_POLICY_VIOLATION``.
    3. Every following JSON message carrying a non-empty ``"message"`` is
       passed to ``run_agent_stream``; each chunk is sent back as
       ``{"message": chunk, "turn_complete": False}``, followed by
       ``{"message": "", "turn_complete": True}`` once the stream ends.

    Args:
        websocket: The client connection.
        agent_id: Identifier of the agent to talk to (path parameter).
        contact_id: Identifier of the contact (path parameter).
        db: Database session provided by ``get_db``.
    """
    try:
        # Accept the connection
        await websocket.accept()
        logger.info("WebSocket connection accepted, waiting for authentication")

        # Wait for the authentication message
        try:
            auth_data = await websocket.receive_json()
            # The payload carries the bearer token, so it is deliberately
            # kept out of the logs.
            logger.info("Received authentication message")

            # A valid JSON document that is not an object (e.g. a list) is an
            # invalid authentication message, not an internal error.
            if (
                not isinstance(auth_data, dict)
                or auth_data.get("type") != "authorization"
                or not auth_data.get("token")
            ):
                logger.warning("Invalid authentication message")
                await _close_websocket(websocket, status.WS_1008_POLICY_VIOLATION)
                return

            # Verify the token
            payload = await get_jwt_token_ws(auth_data["token"])
            if not payload:
                logger.warning("Invalid token")
                await _close_websocket(websocket, status.WS_1008_POLICY_VIOLATION)
                return

            # Check that the agent exists
            agent = agent_service.get_agent(db, agent_id)
            if not agent:
                logger.warning(f"Agent {agent_id} not found")
                await _close_websocket(websocket, status.WS_1008_POLICY_VIOLATION)
                return

            # Check that the user has access to the agent (via client)
            await verify_user_client(payload, db, agent.client_id)

            logger.info(
                f"WebSocket connection established for agent {agent_id} and contact {contact_id}"
            )

            while True:
                try:
                    data = await websocket.receive_json()
                    logger.info(f"Received message: {data}")

                    # Non-object payloads and empty messages are ignored.
                    message = data.get("message") if isinstance(data, dict) else None
                    if not message:
                        continue

                    async for chunk in run_agent_stream(
                        agent_id=agent_id,
                        contact_id=contact_id,
                        message=message,
                        session_service=session_service,
                        artifacts_service=artifacts_service,
                        memory_service=memory_service,
                        db=db,
                    ):
                        # Send each chunk as its own JSON message
                        await websocket.send_json(
                            {"message": chunk, "turn_complete": False}
                        )

                    # Signal that the turn is complete
                    await websocket.send_json({"message": "", "turn_complete": True})

                except WebSocketDisconnect:
                    logger.info("Client disconnected")
                    break
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON message received")
                    continue
                except Exception as e:
                    logger.error(f"Error in WebSocket message handling: {str(e)}")
                    await _close_websocket(websocket, status.WS_1011_INTERNAL_ERROR)
                    break

        except WebSocketDisconnect:
            logger.info("Client disconnected during authentication")
        except json.JSONDecodeError:
            logger.warning("Invalid authentication message format")
            await _close_websocket(websocket, status.WS_1008_POLICY_VIOLATION)
        except HTTPException as e:
            # NOTE(review): the auth helpers presumably signal denial with
            # HTTPException (the HTTP endpoint relies on that) - confirm.
            # Reported as a policy violation rather than an internal error.
            logger.warning(f"Access denied during authentication: {e.detail}")
            await _close_websocket(websocket, status.WS_1008_POLICY_VIOLATION)
        except Exception as e:
            logger.error(f"Error during authentication: {str(e)}")
            await _close_websocket(websocket, status.WS_1011_INTERNAL_ERROR)
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        await _close_websocket(websocket, status.WS_1011_INTERNAL_ERROR)
@router.post(
    "/",
    response_model=ChatResponse,
    responses={
        400: {"model": ErrorResponse},
        404: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def chat(
    request: ChatRequest,
    db: Session = Depends(get_db),
    payload: dict = Depends(get_jwt_token),
):
    """Send one message to an agent and return its final response.

    Args:
        request: Chat request; its ``agent_id``, ``contact_id`` and
            ``message`` fields are forwarded to ``run_agent``.
        db: Database session provided by ``get_db``.
        payload: Value provided by the ``get_jwt_token`` dependency, passed
            on to ``verify_user_client``.

    Returns:
        A dict with the agent's ``response`` text, ``status`` set to
        ``"success"`` and an ISO-formatted ``timestamp``.

    Raises:
        HTTPException: 404 when the agent does not exist (either here or
            reported by ``run_agent`` as ``AgentNotFoundError``); 500 for any
            other unexpected failure while running the agent. Errors raised
            by ``verify_user_client`` propagate unchanged.
    """
    # Verify if the agent belongs to the user's client
    agent = agent_service.get_agent(db, request.agent_id)
    if not agent:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
        )

    # Verify if the user has access to the agent (via client)
    await verify_user_client(payload, db, agent.client_id)

    try:
        final_response_text = await run_agent(
            request.agent_id,
            request.contact_id,
            request.message,
            session_service,
            artifacts_service,
            memory_service,
            db,
        )
    except AgentNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e
    except HTTPException:
        # Keep the status code of an HTTP error raised while running the
        # agent instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        # Record the traceback server-side; the client only receives the
        # exception message.
        logger.exception("Unexpected error while running agent")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e

    return {
        "response": final_response_text,
        "status": "success",
        "timestamp": datetime.now().isoformat(),
    }