185 lines
5.9 KiB
Python
185 lines
5.9 KiB
Python
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
HTTPException,
|
|
status,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
)
|
|
from sqlalchemy.orm import Session
|
|
from src.config.database import get_db
|
|
from src.core.jwt_middleware import (
|
|
get_jwt_token,
|
|
verify_user_client,
|
|
get_jwt_token_ws,
|
|
)
|
|
from src.services import (
|
|
agent_service,
|
|
)
|
|
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
|
|
from src.services.agent_runner import run_agent, run_agent_stream
|
|
from src.core.exceptions import AgentNotFoundError
|
|
from src.services.service_providers import (
|
|
session_service,
|
|
artifacts_service,
|
|
memory_service,
|
|
)
|
|
|
|
from datetime import datetime
|
|
import logging
|
|
import json
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(
|
|
prefix="/chat",
|
|
tags=["chat"],
|
|
responses={404: {"description": "Not found"}},
|
|
)
|
|
|
|
|
|
@router.websocket("/ws/{agent_id}/{contact_id}")
|
|
async def websocket_chat(
|
|
websocket: WebSocket,
|
|
agent_id: str,
|
|
contact_id: str,
|
|
db: Session = Depends(get_db),
|
|
):
|
|
try:
|
|
# Accept the connection
|
|
await websocket.accept()
|
|
logger.info("WebSocket connection accepted, waiting for authentication")
|
|
|
|
# Aguardar mensagem de autenticação
|
|
try:
|
|
auth_data = await websocket.receive_json()
|
|
logger.info(f"Received authentication data: {auth_data}")
|
|
|
|
if not auth_data.get("type") == "authorization" or not auth_data.get(
|
|
"token"
|
|
):
|
|
logger.warning("Invalid authentication message")
|
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
return
|
|
|
|
token = auth_data["token"]
|
|
# Verify the token
|
|
payload = await get_jwt_token_ws(token)
|
|
if not payload:
|
|
logger.warning("Invalid token")
|
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
return
|
|
|
|
# Verificar se o agente pertence ao cliente do usuário
|
|
agent = agent_service.get_agent(db, agent_id)
|
|
if not agent:
|
|
logger.warning(f"Agent {agent_id} not found")
|
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
return
|
|
|
|
# Verificar se o usuário tem acesso ao agente (via client)
|
|
await verify_user_client(payload, db, agent.client_id)
|
|
|
|
logger.info(
|
|
f"WebSocket connection established for agent {agent_id} and contact {contact_id}"
|
|
)
|
|
|
|
while True:
|
|
try:
|
|
data = await websocket.receive_json()
|
|
logger.info(f"Received message: {data}")
|
|
message = data.get("message")
|
|
|
|
if not message:
|
|
continue
|
|
|
|
async for chunk in run_agent_stream(
|
|
agent_id=agent_id,
|
|
contact_id=contact_id,
|
|
message=message,
|
|
session_service=session_service,
|
|
artifacts_service=artifacts_service,
|
|
memory_service=memory_service,
|
|
db=db,
|
|
):
|
|
# Enviar cada chunk como uma mensagem JSON
|
|
await websocket.send_json(
|
|
{"message": chunk, "turn_complete": False}
|
|
)
|
|
|
|
# Enviar sinal de turno completo
|
|
await websocket.send_json({"message": "", "turn_complete": True})
|
|
|
|
except WebSocketDisconnect:
|
|
logger.info("Client disconnected")
|
|
break
|
|
except json.JSONDecodeError:
|
|
logger.warning("Invalid JSON message received")
|
|
continue
|
|
except Exception as e:
|
|
logger.error(f"Error in WebSocket message handling: {str(e)}")
|
|
await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
|
|
break
|
|
|
|
except WebSocketDisconnect:
|
|
logger.info("Client disconnected during authentication")
|
|
except json.JSONDecodeError:
|
|
logger.warning("Invalid authentication message format")
|
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
except Exception as e:
|
|
logger.error(f"Error during authentication: {str(e)}")
|
|
await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket error: {str(e)}")
|
|
await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
|
|
|
|
|
|
@router.post(
|
|
"/",
|
|
response_model=ChatResponse,
|
|
responses={
|
|
400: {"model": ErrorResponse},
|
|
404: {"model": ErrorResponse},
|
|
500: {"model": ErrorResponse},
|
|
},
|
|
)
|
|
async def chat(
|
|
request: ChatRequest,
|
|
db: Session = Depends(get_db),
|
|
payload: dict = Depends(get_jwt_token),
|
|
):
|
|
# Verify if the agent belongs to the user's client
|
|
agent = agent_service.get_agent(db, request.agent_id)
|
|
if not agent:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
|
|
)
|
|
|
|
# Verify if the user has access to the agent (via client)
|
|
await verify_user_client(payload, db, agent.client_id)
|
|
|
|
try:
|
|
final_response_text = await run_agent(
|
|
request.agent_id,
|
|
request.contact_id,
|
|
request.message,
|
|
session_service,
|
|
artifacts_service,
|
|
memory_service,
|
|
db,
|
|
)
|
|
|
|
return {
|
|
"response": final_response_text,
|
|
"status": "success",
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
|
|
except AgentNotFoundError as e:
|
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
|
)
|