diff --git a/CHANGELOG.md b/CHANGELOG.md index fd017f10..95abcbb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.0.9] - 2025-05-13 +### Added + +- Add API key sharing and flexible authentication for chat routes + ### Changed - Enhance user authentication with detailed error handling diff --git a/src/api/agent_routes.py b/src/api/agent_routes.py index 2b0d411a..4369f209 100644 --- a/src/api/agent_routes.py +++ b/src/api/agent_routes.py @@ -560,3 +560,64 @@ async def delete_agent( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" ) + + +@router.post("/{agent_id}/share", response_model=Dict[str, str]) +async def share_agent( + agent_id: uuid.UUID, + x_client_id: uuid.UUID = Header(..., alias="x-client-id"), + db: Session = Depends(get_db), + payload: dict = Depends(get_jwt_token), +): + """Returns the agent's API key for sharing""" + await verify_user_client(payload, db, x_client_id) + + # Verify if the agent exists + agent = agent_service.get_agent(db, agent_id) + if not agent: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" + ) + + # Verify if the agent belongs to the specified client + if agent.client_id != x_client_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Agent does not belong to the specified client", + ) + + # Verify if API key exists + if not agent.config or not agent.config.get("api_key"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This agent does not have an API key", + ) + + return {"api_key": agent.config["api_key"]} + + +@router.get("/{agent_id}/shared", response_model=Agent) +async def get_shared_agent( + agent_id: uuid.UUID, + api_key: str = Header(..., alias="x-api-key"), + db: Session = Depends(get_db), +): + """Get agent details using only API key authentication""" + # Verify if the agent exists + agent = agent_service.get_agent(db, agent_id) + if not agent or not agent.config: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" + ) + + # Verify if the API key matches + if not agent.config.get("api_key") or agent.config.get("api_key") != api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" + ) + + # Add agent card URL if not present + if not agent.agent_card_url: + agent.agent_card_url = agent.agent_card_url_property + + return agent diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index f42d3d54..8846080b 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -27,6 +27,7 @@ └──────────────────────────────────────────────────────────────────────────────┘ """ +import uuid from fastapi import ( APIRouter, Depends, @@ -34,6 +35,7 @@ from fastapi import ( status, WebSocket, WebSocketDisconnect, + Header, ) from sqlalchemy.orm import Session from src.config.database import get_db @@ -57,6 +59,7 @@ from src.services.service_providers import ( from datetime import datetime import logging import json +from typing import Optional, Dict logger = logging.getLogger(__name__) @@ -67,6 +70,59 @@ router = APIRouter( ) +async def get_agent_by_api_key( + agent_id: str, + api_key: Optional[str] = Header(None, alias="x-api-key"), + authorization: Optional[str] = Header(None), + db: Session = Depends(get_db), +): + """Flexible authentication for chat routes, allowing JWT or API key""" + if authorization: + # Try to authenticate with JWT token first + try: + # Extract token from Authorization header if needed + token = ( + authorization.replace("Bearer ", "") + if authorization.startswith("Bearer ") + else authorization + ) + payload = await get_jwt_token(token) + agent = agent_service.get_agent(db, agent_id) + if not agent: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Agent not found", + ) + + # Verify if the user has access to the agent's client + await verify_user_client(payload, db, agent.client_id) + return agent + except Exception as e: + logger.warning(f"JWT authentication failed: {str(e)}") + # If JWT fails, continue to try with API key + + # Try to authenticate with API key + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required (JWT or API key)", + ) + + agent = agent_service.get_agent(db, agent_id) + if not agent or not agent.config: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" + ) + + # Verify if the API key matches + if not agent.config.get("api_key") or agent.config.get("api_key") != api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" + ) + + return agent + + @router.websocket("/ws/{agent_id}/{external_id}") async def websocket_chat( websocket: WebSocket, @@ -82,32 +138,49 @@ async def websocket_chat( # Wait for authentication message try: auth_data = await websocket.receive_json() - logger.info(f"Received authentication data: {auth_data}") + logger.info(f"Authentication data received: {auth_data}") if not ( - auth_data.get("type") == "authorization" and auth_data.get("token") + auth_data.get("type") == "authorization" + and (auth_data.get("token") or auth_data.get("api_key")) ): logger.warning("Invalid authentication message") await websocket.close(code=status.WS_1008_POLICY_VIOLATION) return - token = auth_data["token"] - # Verify the token - payload = await get_jwt_token_ws(token) - if not payload: - logger.warning("Invalid token") - await websocket.close(code=status.WS_1008_POLICY_VIOLATION) - return - - # Verify if the agent belongs to the user's client + # Verify if the agent exists agent = agent_service.get_agent(db, agent_id) if not agent: logger.warning(f"Agent {agent_id} not found") await websocket.close(code=status.WS_1008_POLICY_VIOLATION) return - # Verify if the user has access to the agent (via client) - await verify_user_client(payload, db, agent.client_id) + # Verify authentication + is_authenticated = False + + # Try with JWT token + if auth_data.get("token"): + try: + payload = await get_jwt_token_ws(auth_data["token"]) + if payload: + # Verify if the user has access to the agent + await verify_user_client(payload, db, agent.client_id) + is_authenticated = True + except Exception as e: + logger.warning(f"JWT authentication failed: {str(e)}") + + # If JWT fails, try with API key + if not is_authenticated and auth_data.get("api_key"): + if agent.config and agent.config.get("api_key") == auth_data.get( + "api_key" + ): + is_authenticated = True + else: + logger.warning("Invalid API key") + + if not is_authenticated: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return logger.info( f"WebSocket connection established for agent {agent_id} and external_id {external_id}" @@ -174,19 +247,9 @@ async def websocket_chat( ) async def chat( request: ChatRequest, + _=Depends(get_agent_by_api_key), db: Session = Depends(get_db), - payload: dict = Depends(get_jwt_token), ): - # Verify if the agent belongs to the user's client - agent = agent_service.get_agent(db, request.agent_id) - if not agent: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found" - ) - - # Verify if the user has access to the agent (via client) - await verify_user_client(payload, db, agent.client_id) - try: final_response = await run_agent( request.agent_id, diff --git a/src/services/a2a_task_manager.py b/src/services/a2a_task_manager.py index 581a2a26..ce86676d 100644 --- a/src/services/a2a_task_manager.py +++ b/src/services/a2a_task_manager.py @@ -411,37 +411,14 @@ class A2ATaskManager: ): try: chunk_data = json.loads(chunk) - except Exception as e: - logger.warning(f"Invalid chunk received: {chunk} - {e}") - continue - if ( - isinstance(chunk_data, dict) - and "type" in chunk_data - and chunk_data["type"] - in [ - "history", - "history_update", - "history_complete", - ] - ): - continue + if isinstance(chunk_data, dict) and "content" in chunk_data: + content = chunk_data.get("content", {}) + role = content.get("role", "agent") + parts = content.get("parts", []) - if isinstance(chunk_data, dict): - if "type" not in chunk_data and "text" in chunk_data: - chunk_data["type"] = "text" - - if "type" in chunk_data: - try: - update_message = Message(role="agent", parts=[chunk_data]) - - await self.update_store( - request.params.id, - TaskStatus( - state=TaskState.WORKING, message=update_message - ), - update_history=False, - ) + if parts: + update_message = Message(role=role, parts=parts) yield SendTaskStreamingResponse( id=request.id, @@ -455,14 +432,13 @@ class A2ATaskManager: ), ) - if chunk_data.get("type") == "text": - full_response += chunk_data.get("text", "") - final_message = update_message - - except Exception as e: - logger.error( - f"Error processing chunk: {e}, chunk: {chunk_data}" - ) + for part in parts: + if part.get("type") == "text": + full_response += part.get("text", "") + final_message = update_message + except Exception as e: + logger.error(f"Error processing chunk: {e}, chunk: {chunk}") + continue # Determine the final state of the task task_state = (