feat(api): add push notification service and integrate with agent task handling
This commit is contained in:
parent
690168fa5d
commit
465efc6936
@ -1,9 +1,10 @@
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from src.config.database import get_db
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uuid
|
||||
from src.core.jwt_middleware import (
|
||||
get_jwt_token,
|
||||
@ -23,6 +24,7 @@ from src.services.service_providers import (
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
)
|
||||
from src.services.push_notification_service import push_notification_service
|
||||
import logging
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ..services.streaming_service import StreamingService
|
||||
@ -221,18 +223,32 @@ async def get_agent_json(
|
||||
@router.post("/{agent_id}/tasks/send")
|
||||
async def handle_task(
|
||||
agent_id: uuid.UUID,
|
||||
request: Request,
|
||||
request: JSONRPCRequest,
|
||||
x_api_key: str = Header(..., alias="x-api-key"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Endpoint to clients A2A send a new task (with an initial user message)."""
|
||||
try:
|
||||
# Verify JSON-RPC method
|
||||
if request.method != "tasks/send":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
"data": {"detail": f"Method '{request.method}' not found"},
|
||||
},
|
||||
}
|
||||
|
||||
# Verify agent
|
||||
agent = agent_service.get_agent(db, agent_id)
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found"
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": 404, "message": "Agent not found", "data": None},
|
||||
}
|
||||
|
||||
# Verify API key
|
||||
agent_config = agent.config
|
||||
@ -242,29 +258,35 @@ async def handle_task(
|
||||
detail="Invalid API key for this agent",
|
||||
)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
task_request = await request.json()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request format"
|
||||
)
|
||||
# Extract task request from JSON-RPC params
|
||||
task_request = request.params
|
||||
|
||||
# Validate required fields
|
||||
task_id = task_request.get("id")
|
||||
if not task_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Task ID is required"
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Invalid parameters",
|
||||
"data": {"detail": "Task ID is required"},
|
||||
},
|
||||
}
|
||||
|
||||
# Extract user message
|
||||
try:
|
||||
user_message = task_request["message"]["parts"][0]["text"]
|
||||
except (KeyError, IndexError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid message format"
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Invalid parameters",
|
||||
"data": {"detail": "Invalid message format"},
|
||||
},
|
||||
}
|
||||
|
||||
# Configure session and metadata
|
||||
session_id = f"{task_id}_{agent_id}"
|
||||
@ -276,7 +298,7 @@ async def handle_task(
|
||||
"id": task_id,
|
||||
"sessionId": session_id,
|
||||
"status": {
|
||||
"state": "running",
|
||||
"state": "submitted",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": None,
|
||||
"error": None,
|
||||
@ -286,7 +308,50 @@ async def handle_task(
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
# Handle push notification configuration
|
||||
push_notification = task_request.get("pushNotification")
|
||||
if push_notification:
|
||||
url = push_notification.get("url")
|
||||
headers = push_notification.get("headers", {})
|
||||
|
||||
if not url:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Invalid parameters",
|
||||
"data": {"detail": "Push notification URL is required"},
|
||||
},
|
||||
}
|
||||
|
||||
# Store push notification config in metadata
|
||||
response_task["metadata"]["pushNotification"] = {
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
# Send initial notification
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url, task_id=task_id, state="submitted", headers=headers
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Update status to running
|
||||
response_task["status"].update(
|
||||
{"state": "running", "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
|
||||
# Send running notification if configured
|
||||
if push_notification:
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url, task_id=task_id, state="running", headers=headers
|
||||
)
|
||||
)
|
||||
|
||||
# Execute agent
|
||||
final_response_text = await run_agent(
|
||||
str(agent_id),
|
||||
@ -324,6 +389,21 @@ async def handle_task(
|
||||
}
|
||||
)
|
||||
|
||||
# Send completed notification if configured
|
||||
if push_notification:
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url,
|
||||
task_id=task_id,
|
||||
state="completed",
|
||||
message={
|
||||
"role": "agent",
|
||||
"parts": [{"type": "text", "text": final_response_text}],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Update status to failed
|
||||
response_task["status"].update(
|
||||
@ -334,6 +414,21 @@ async def handle_task(
|
||||
}
|
||||
)
|
||||
|
||||
# Send failed notification if configured
|
||||
if push_notification:
|
||||
asyncio.create_task(
|
||||
push_notification_service.send_notification(
|
||||
url=url,
|
||||
task_id=task_id,
|
||||
state="failed",
|
||||
message={
|
||||
"role": "system",
|
||||
"parts": [{"type": "text", "text": str(e)}],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
|
||||
# Process history
|
||||
try:
|
||||
history_messages = get_session_events(session_service, session_id)
|
||||
@ -361,20 +456,26 @@ async def handle_task(
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing history: {str(e)}")
|
||||
|
||||
# pushNotification = task_request.get("pushNotification", False)
|
||||
# if pushNotification:
|
||||
# await send_push_notification(task_id, final_response_text)
|
||||
# Return JSON-RPC response
|
||||
return {"jsonrpc": "2.0", "id": request.id, "result": response_task}
|
||||
|
||||
return response_task
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except HTTPException as e:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": e.status_code, "message": e.detail, "data": None},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in handle_task: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error",
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": "Internal server error",
|
||||
"data": {"detail": str(e)},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/tasks/sendSubscribe")
|
||||
@ -396,8 +497,24 @@ async def subscribe_task_streaming(
|
||||
Returns:
|
||||
StreamingResponse com eventos SSE
|
||||
"""
|
||||
# Verify JSON-RPC method
|
||||
if request.method != "tasks/sendSubscribe":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
"data": {"detail": f"Method '{request.method}' not found"},
|
||||
},
|
||||
}
|
||||
|
||||
if not x_api_key:
|
||||
raise HTTPException(status_code=401, detail="API key é obrigatória")
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.id,
|
||||
"error": {"code": 401, "message": "API key é obrigatória", "data": None},
|
||||
}
|
||||
|
||||
# Extrai mensagem do payload
|
||||
message = request.params.get("message", {}).get("parts", [{}])[0].get("text", "")
|
||||
|
64
src/services/push_notification_service.py
Normal file
64
src/services/push_notification_service.py
Normal file
@ -0,0 +1,64 @@
|
||||
import aiohttp
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PushNotificationService:
|
||||
def __init__(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
async def send_notification(
|
||||
self,
|
||||
url: str,
|
||||
task_id: str,
|
||||
state: str,
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
) -> bool:
|
||||
"""
|
||||
Envia uma notificação push para a URL especificada.
|
||||
Implementa retry com backoff exponencial.
|
||||
"""
|
||||
payload = {
|
||||
"taskId": task_id,
|
||||
"state": state,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=headers or {}, timeout=10
|
||||
) as response:
|
||||
if response.status in (200, 201, 202, 204):
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Push notification failed with status {response.status}. "
|
||||
f"Attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error sending push notification: {str(e)}. "
|
||||
f"Attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(retry_delay * (2**attempt))
|
||||
|
||||
return False
|
||||
|
||||
async def close(self):
|
||||
"""Fecha a sessão HTTP"""
|
||||
await self.session.close()
|
||||
|
||||
|
||||
# Instância global do serviço
|
||||
push_notification_service = PushNotificationService()
|
Loading…
Reference in New Issue
Block a user