303 lines
12 KiB
Python
303 lines
12 KiB
Python
"""
|
|
┌──────────────────────────────────────────────────────────────────────────────┐
|
|
│ @author: Davidson Gomes │
|
|
│ @file: chat_routes.py │
|
|
│ Developed by: Davidson Gomes │
|
|
│ Creation date: May 13, 2025 │
|
|
│ Contact: contato@evolution-api.com │
|
|
├──────────────────────────────────────────────────────────────────────────────┤
|
|
│ @copyright © Evolution API 2025. All rights reserved. │
|
|
│ Licensed under the Apache License, Version 2.0 │
|
|
│ │
|
|
│ You may not use this file except in compliance with the License. │
|
|
│ You may obtain a copy of the License at │
|
|
│ │
|
|
│ http://www.apache.org/licenses/LICENSE-2.0 │
|
|
│ │
|
|
│ Unless required by applicable law or agreed to in writing, software │
|
|
│ distributed under the License is distributed on an "AS IS" BASIS, │
|
|
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
|
|
│ See the License for the specific language governing permissions and │
|
|
│ limitations under the License. │
|
|
├──────────────────────────────────────────────────────────────────────────────┤
|
|
│ @important │
|
|
│ For any future changes to the code in this file, it is recommended to │
|
|
│ include, together with the modification, the information of the developer │
|
|
│ who changed it and the date of modification. │
|
|
└──────────────────────────────────────────────────────────────────────────────┘
|
|
"""
|
|
|
|
import base64
import hmac
import json
import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    status,
    WebSocket,
    WebSocketDisconnect,
    Header,
)
from sqlalchemy.orm import Session

from src.config.database import get_db
from src.core.jwt_middleware import (
    get_jwt_token,
    verify_user_client,
    get_jwt_token_ws,
)
from src.services import (
    agent_service,
)
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
from src.services.agent_runner import run_agent, run_agent_stream
from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import (
    session_service,
    artifacts_service,
    memory_service,
)
|
|
|
|
# Logger named after this module (its dotted import path).
logger = logging.getLogger(__name__)

# Router for every chat endpoint in this module: routes are mounted under
# "/chat", grouped under the "chat" OpenAPI tag, and document a generic 404.
router = APIRouter(
    responses={404: {"description": "Not found"}},
    tags=["chat"],
    prefix="/chat",
)
|
|
|
|
|
|
async def get_agent_by_api_key(
    agent_id: str,
    api_key: Optional[str] = Header(None, alias="x-api-key"),
    authorization: Optional[str] = Header(None),
    db: Session = Depends(get_db),
):
    """Authenticate a chat request with either a JWT or the agent's API key.

    The JWT from the ``Authorization`` header (with or without a leading
    ``Bearer `` prefix) is tried first. If anything on that path fails --
    token validation, agent lookup, or the client-access check -- the failure
    is logged and authentication falls through to the ``x-api-key`` header.

    Args:
        agent_id: Identifier of the agent being accessed.
        api_key: Value of the ``x-api-key`` header, if present.
        authorization: Value of the ``Authorization`` header, if present.
        db: Database session used to look up the agent.

    Returns:
        The agent returned by ``agent_service.get_agent`` once the caller is
        authenticated.

    Raises:
        HTTPException: 401 when no API key is available after the JWT path
            fails or is skipped, or when the API key does not match; 404 when
            the agent is missing (or, on the API-key path, has no config).
    """
    if authorization:
        # Try to authenticate with the JWT token first.
        try:
            bearer_prefix = "Bearer "
            # Remove only a *leading* prefix; str.replace would strip every
            # occurrence of the substring anywhere in the header value.
            token = (
                authorization[len(bearer_prefix):]
                if authorization.startswith(bearer_prefix)
                else authorization
            )
            payload = await get_jwt_token(token)

            agent = agent_service.get_agent(db, agent_id)
            if not agent:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Agent not found",
                )

            # Verify the user has access to the agent's client.
            await verify_user_client(payload, db, agent.client_id)
            return agent
        except Exception as e:
            # Deliberate fall-through: a failed JWT attempt still lets a
            # valid API key authenticate the request below.
            logger.warning(f"JWT authentication failed: {str(e)}")

    # Try to authenticate with the API key.
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required (JWT or API key)",
        )

    agent = agent_service.get_agent(db, agent_id)
    if not agent or not agent.config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
        )

    expected_key = agent.config.get("api_key")
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed key was correct. Only a non-empty string key can ever match
    # (the header value is always a string).
    key_matches = (
        isinstance(expected_key, str)
        and bool(expected_key)
        and hmac.compare_digest(
            expected_key.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not key_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
        )

    return agent
|
|
|
|
|
|
async def _safe_close(websocket: WebSocket, code: int) -> None:
    """Close *websocket* with *code*, logging instead of raising on failure.

    Closing can itself fail (e.g. when the socket is already closed); that
    secondary failure must not mask the error that triggered the close.
    """
    try:
        await websocket.close(code=code)
    except Exception as close_error:
        logger.debug(f"Ignoring error while closing WebSocket: {str(close_error)}")


def _parse_ws_files(raw_files: Any) -> Optional[List[FileData]]:
    """Build ``FileData`` objects from the ``files`` field of a WebSocket message.

    Returns None when *raw_files* is empty/missing or not a list, or when
    constructing any ``FileData`` raises (the error is logged). Otherwise
    returns one ``FileData`` per entry that is a dict with non-empty
    ``filename``, ``content_type`` and ``data``; other entries are skipped.
    """
    if not raw_files or not isinstance(raw_files, list):
        return None
    try:
        files = [
            FileData(
                filename=entry.get("filename"),
                content_type=entry.get("content_type"),
                data=entry.get("data"),
            )
            for entry in raw_files
            if isinstance(entry, dict)
            and entry.get("filename")
            and entry.get("content_type")
            and entry.get("data")
        ]
    except Exception as e:
        logger.error(f"Error processing files: {str(e)}")
        return None
    logger.info(f"Processed {len(files)} files via WebSocket")
    return files


async def _authenticate_ws(auth_data: Dict[str, Any], agent: Any, db: Session) -> bool:
    """Return True when *auth_data* authenticates the caller for *agent*.

    A JWT in ``token`` is tried first; it succeeds only if the decoded
    payload is truthy and ``verify_user_client`` accepts it for the agent's
    client. Otherwise an ``api_key`` equal to the agent's configured key is
    accepted (compared in constant time).
    """
    token = auth_data.get("token")
    if token:
        try:
            payload = await get_jwt_token_ws(token)
            if payload:
                # Verify the user has access to the agent's client.
                await verify_user_client(payload, db, agent.client_id)
                return True
        except Exception as e:
            logger.warning(f"JWT authentication failed: {str(e)}")

    # If JWT fails (or is absent), try the API key.
    provided_key = auth_data.get("api_key")
    if not provided_key:
        return False
    expected_key = agent.config.get("api_key") if agent.config else None
    if (
        isinstance(expected_key, str)
        and isinstance(provided_key, str)
        and hmac.compare_digest(
            expected_key.encode("utf-8"), provided_key.encode("utf-8")
        )
    ):
        return True
    logger.warning("Invalid API key")
    return False


@router.websocket("/ws/{agent_id}/{external_id}")
async def websocket_chat(
    websocket: WebSocket,
    agent_id: str,
    external_id: str,
    db: Session = Depends(get_db),
):
    """Authenticate a client and stream agent responses over a WebSocket.

    Protocol, as implemented here:

    1. The connection is accepted, then the first JSON message must be an
       object with ``"type": "authorization"`` and a ``token`` (JWT) and/or
       ``api_key``. Anything else closes the socket with 1008.
    2. The agent must exist and the credentials must authenticate for it
       (see ``_authenticate_ws``); otherwise the socket is closed with 1008.
    3. Each following JSON object with a non-empty ``message`` (plus optional
       ``files``) is run through ``run_agent_stream``. Every chunk is sent as
       ``{"message": json.loads(chunk), "turn_complete": False}``, followed by
       ``{"message": "", "turn_complete": True}``. Messages without
       ``message``, and non-object messages, are ignored.

    Unexpected errors close the socket with 1011.
    """
    try:
        await websocket.accept()
        logger.info("WebSocket connection accepted, waiting for authentication")

        try:
            auth_data = await websocket.receive_json()
            # Never log the payload itself: it carries the JWT / API key.
            logger.info("Authentication message received")

            if not (
                isinstance(auth_data, dict)
                and auth_data.get("type") == "authorization"
                and (auth_data.get("token") or auth_data.get("api_key"))
            ):
                logger.warning("Invalid authentication message")
                await _safe_close(websocket, status.WS_1008_POLICY_VIOLATION)
                return

            # Verify the agent exists.
            agent = agent_service.get_agent(db, agent_id)
            if not agent:
                logger.warning(f"Agent {agent_id} not found")
                await _safe_close(websocket, status.WS_1008_POLICY_VIOLATION)
                return

            if not await _authenticate_ws(auth_data, agent, db):
                await _safe_close(websocket, status.WS_1008_POLICY_VIOLATION)
                return

            logger.info(
                f"WebSocket connection established for agent {agent_id} and external_id {external_id}"
            )

            while True:
                try:
                    data = await websocket.receive_json()
                    if not isinstance(data, dict):
                        # A stray non-object frame should not end the session.
                        logger.warning("Ignoring non-object WebSocket message")
                        continue
                    # Log only the field names: message text and base64 file
                    # payloads are user data and can be very large.
                    logger.info(f"Received message with fields: {list(data.keys())}")

                    message = data.get("message")
                    if not message:
                        continue

                    files = _parse_ws_files(data.get("files"))

                    async for chunk in run_agent_stream(
                        agent_id=agent_id,
                        external_id=external_id,
                        message=message,
                        session_service=session_service,
                        artifacts_service=artifacts_service,
                        memory_service=memory_service,
                        db=db,
                        files=files,
                    ):
                        # NOTE(review): assumes each chunk is a JSON string; a
                        # non-JSON chunk lands in the JSONDecodeError handler
                        # below and the turn_complete signal is skipped.
                        await websocket.send_json(
                            {"message": json.loads(chunk), "turn_complete": False}
                        )

                    # Signal that the turn is complete.
                    await websocket.send_json({"message": "", "turn_complete": True})

                except WebSocketDisconnect:
                    logger.info("Client disconnected")
                    break
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON message received")
                    continue
                except Exception as e:
                    logger.error(f"Error in WebSocket message handling: {str(e)}")
                    await _safe_close(websocket, status.WS_1011_INTERNAL_ERROR)
                    break

        except WebSocketDisconnect:
            logger.info("Client disconnected during authentication")
        except json.JSONDecodeError:
            logger.warning("Invalid authentication message format")
            await _safe_close(websocket, status.WS_1008_POLICY_VIOLATION)
        except Exception as e:
            logger.error(f"Error during authentication: {str(e)}")
            await _safe_close(websocket, status.WS_1011_INTERNAL_ERROR)

    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        await _safe_close(websocket, status.WS_1011_INTERNAL_ERROR)
|
|
|
|
|
|
def _same_agent_id(first: Any, second: Any) -> bool:
    """Return True when two agent identifiers refer to the same agent.

    Values that parse as UUIDs are compared as UUIDs, so case and hyphenation
    differences do not matter; anything else is compared as text.
    """
    try:
        return uuid.UUID(str(first)) == uuid.UUID(str(second))
    except ValueError:
        return str(first) == str(second)


@router.post(
    "",
    response_model=ChatResponse,
    responses={
        400: {"model": ErrorResponse},
        403: {"model": ErrorResponse},
        404: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def chat(
    request: ChatRequest,
    _=Depends(get_agent_by_api_key),
    db: Session = Depends(get_db),
):
    """Run an agent for one message and return its final response.

    ``_`` is the agent authenticated by ``get_agent_by_api_key``. That
    dependency resolves its ``agent_id`` from the request itself (a query
    parameter for this body-only route), whereas the agent executed here is
    ``request.agent_id`` from the JSON body. The two must refer to the same
    agent; otherwise a caller could authenticate against one agent and run
    another.

    Args:
        request: Chat payload (agent_id, external_id, message, optional files).
        _: Authenticated agent returned by ``get_agent_by_api_key``.
        db: Database session passed through to ``run_agent``.

    Returns:
        A dict with ``response``, ``message_history``, ``status`` and an ISO
        ``timestamp``, validated against ``ChatResponse``.

    Raises:
        HTTPException: 403 if the body's agent_id differs from the
            authenticated agent; 404 if ``run_agent`` raises
            ``AgentNotFoundError``; 500 for any other error.
    """
    # NOTE(review): assumes the agent model exposes its primary key as `id`.
    # Checked before the try block so the 403 is not rewritten into a 500.
    if not _same_agent_id(_.id, request.agent_id):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Agent ID does not match the authenticated agent",
        )

    try:
        final_response = await run_agent(
            request.agent_id,
            request.external_id,
            request.message,
            session_service,
            artifacts_service,
            memory_service,
            db,
            files=request.files,
        )

        return {
            "response": final_response["final_response"],
            "message_history": final_response["message_history"],
            "status": "success",
            "timestamp": datetime.now().isoformat(),
        }

    except AgentNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except Exception as e:
        # Keep the traceback server-side; the client only receives str(e).
        logger.exception(f"Error running agent {request.agent_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
|