refactor(a2a): enhance A2A request processing and task management with session ID handling
This commit is contained in:
parent
8f1fef71a5
commit
a0f984ae21
@ -124,7 +124,7 @@ def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"):
|
|||||||
return task_manager
|
return task_manager
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{agent_id}/rpc")
|
@router.post("/{agent_id}")
|
||||||
async def process_a2a_request(
|
async def process_a2a_request(
|
||||||
agent_id: uuid.UUID,
|
agent_id: uuid.UUID,
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -156,6 +156,13 @@ async def process_a2a_request(
|
|||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
method = body.get("method", "unknown")
|
method = body.get("method", "unknown")
|
||||||
|
request_id = body.get("id") # Extract request ID to ensure it's preserved
|
||||||
|
|
||||||
|
# Extrair sessionId do params se for tasks/send
|
||||||
|
session_id = None
|
||||||
|
if method == "tasks/send" and body.get("params"):
|
||||||
|
session_id = body.get("params", {}).get("sessionId")
|
||||||
|
logger.info(f"Extracted sessionId from request: {session_id}")
|
||||||
|
|
||||||
is_query_request = method in [
|
is_query_request = method in [
|
||||||
"tasks/get",
|
"tasks/get",
|
||||||
@ -189,7 +196,7 @@ async def process_a2a_request(
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
content={
|
content={
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": None,
|
"id": request_id, # Use the extracted request ID
|
||||||
"error": {"code": 404, "message": "Agent not found", "data": None},
|
"error": {"code": 404, "message": "Agent not found", "data": None},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -203,7 +210,7 @@ async def process_a2a_request(
|
|||||||
status_code=401,
|
status_code=401,
|
||||||
content={
|
content={
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": None,
|
"id": request_id, # Use the extracted request ID
|
||||||
"error": {"code": 401, "message": "Invalid API key", "data": None},
|
"error": {"code": 401, "message": "Invalid API key", "data": None},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -226,7 +233,7 @@ async def process_a2a_request(
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
content={
|
content={
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": body.get("id"),
|
"id": request_id, # Use the extracted request ID
|
||||||
"error": {
|
"error": {
|
||||||
"code": -32600,
|
"code": -32600,
|
||||||
"message": "Invalid Request: jsonrpc must be '2.0'",
|
"message": "Invalid Request: jsonrpc must be '2.0'",
|
||||||
@ -241,7 +248,7 @@ async def process_a2a_request(
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
content={
|
content={
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": body.get("id"),
|
"id": request_id, # Use the extracted request ID
|
||||||
"error": {
|
"error": {
|
||||||
"code": -32600,
|
"code": -32600,
|
||||||
"message": "Invalid Request: method is required",
|
"message": "Invalid Request: method is required",
|
||||||
@ -250,15 +257,24 @@ async def process_a2a_request(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Processar a requisição normalmente
|
||||||
return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
|
return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
|
logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
|
||||||
|
# Try to extract request ID from the body, if available
|
||||||
|
request_id = None
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
request_id = body.get("id")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
content={
|
content={
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": None,
|
"id": request_id, # Use the extracted request ID or None
|
||||||
"error": {
|
"error": {
|
||||||
"code": -32603,
|
"code": -32603,
|
||||||
"message": "Internal server error",
|
"message": "Internal server error",
|
||||||
|
@ -88,6 +88,7 @@ class AgentRunnerAdapter:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Use the existing agent runner function
|
# Use the existing agent runner function
|
||||||
|
# Usar o session_id fornecido, ou gerar um novo
|
||||||
session_id = session_id or str(uuid.uuid4())
|
session_id = session_id or str(uuid.uuid4())
|
||||||
task_id = task_id or str(uuid.uuid4())
|
task_id = task_id or str(uuid.uuid4())
|
||||||
|
|
||||||
@ -105,12 +106,28 @@ class AgentRunnerAdapter:
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Format the response to include both the A2A-compatible message
|
||||||
|
# and the artifact format to match the Google A2A implementation
|
||||||
|
# Nota: O formato dos artifacts é simplificado para compatibilidade com Google A2A
|
||||||
|
message_obj = {
|
||||||
|
"role": "agent",
|
||||||
|
"parts": [{"type": "text", "text": response_text}],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Formato de artefato compatível com Google A2A
|
||||||
|
artifact_obj = {
|
||||||
|
"parts": [{"type": "text", "text": response_text}],
|
||||||
|
"index": 0,
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"content": response_text,
|
"content": response_text,
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"message": message_obj,
|
||||||
|
"artifact": artifact_obj,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -121,6 +138,10 @@ class AgentRunnerAdapter:
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"message": {
|
||||||
|
"role": "agent",
|
||||||
|
"parts": [{"type": "text", "text": f"Error: {str(e)}"}],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
async def cancel_task(self, task_id: str) -> bool:
|
async def cancel_task(self, task_id: str) -> bool:
|
||||||
|
@ -9,6 +9,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Union, AsyncIterable
|
from typing import Any, Dict, Union, AsyncIterable
|
||||||
|
import uuid
|
||||||
|
|
||||||
from src.schemas.a2a.exceptions import (
|
from src.schemas.a2a.exceptions import (
|
||||||
TaskNotFoundError,
|
TaskNotFoundError,
|
||||||
@ -263,8 +264,9 @@ class A2ATaskManager:
|
|||||||
f"task_notification:{task_id}", params.pushNotification.model_dump()
|
f"task_notification:{task_id}", params.pushNotification.model_dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start task execution in background
|
# Execute task SYNCHRONOUSLY instead of in background
|
||||||
asyncio.create_task(self._execute_task(task_data, params))
|
# This is the key change for A2A compatibility
|
||||||
|
task_data = await self._execute_task(task_data, params)
|
||||||
|
|
||||||
# Convert to Task object and return
|
# Convert to Task object and return
|
||||||
task = Task.model_validate(task_data)
|
task = Task.model_validate(task_data)
|
||||||
@ -523,7 +525,8 @@ class A2ATaskManager:
|
|||||||
# Create task with initial status
|
# Create task with initial status
|
||||||
task_data = {
|
task_data = {
|
||||||
"id": params.id,
|
"id": params.id,
|
||||||
"sessionId": params.sessionId,
|
"sessionId": params.sessionId
|
||||||
|
or str(uuid.uuid4()), # Preservar sessionId quando fornecido
|
||||||
"status": {
|
"status": {
|
||||||
"state": TaskState.SUBMITTED,
|
"state": TaskState.SUBMITTED,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
@ -531,7 +534,7 @@ class A2ATaskManager:
|
|||||||
"error": None,
|
"error": None,
|
||||||
},
|
},
|
||||||
"artifacts": [],
|
"artifacts": [],
|
||||||
"history": [params.message.model_dump()],
|
"history": [params.message.model_dump()], # Apenas mensagem do usuário
|
||||||
"metadata": params.metadata or {},
|
"metadata": params.metadata or {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -540,7 +543,9 @@ class A2ATaskManager:
|
|||||||
|
|
||||||
return task_data
|
return task_data
|
||||||
|
|
||||||
async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None:
|
async def _execute_task(
|
||||||
|
self, task: Dict[str, Any], params: TaskSendParams
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Execute a task using the agent adapter.
|
Execute a task using the agent adapter.
|
||||||
|
|
||||||
@ -550,6 +555,9 @@ class A2ATaskManager:
|
|||||||
Args:
|
Args:
|
||||||
task: Task data to be executed
|
task: Task data to be executed
|
||||||
params: Send task parameters
|
params: Send task parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated task data with completed status and response
|
||||||
"""
|
"""
|
||||||
task_id = task["id"]
|
task_id = task["id"]
|
||||||
agent_id = params.agentId
|
agent_id = params.agentId
|
||||||
@ -562,20 +570,22 @@ class A2ATaskManager:
|
|||||||
message_text += part.text
|
message_text += part.text
|
||||||
|
|
||||||
if not message_text:
|
if not message_text:
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id, TaskState.FAILED, "Message does not contain text", final=True
|
task_id, TaskState.FAILED, "Message does not contain text", final=True
|
||||||
)
|
)
|
||||||
return
|
# Return the updated task data
|
||||||
|
return await self.redis_cache.get(f"task:{task_id}")
|
||||||
|
|
||||||
# Check if it is an ongoing execution
|
# Check if it is an ongoing execution
|
||||||
task_status = task.get("status", {})
|
task_status = task.get("status", {})
|
||||||
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
|
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
|
||||||
logger.info(f"Task {task_id} is already in execution or completed")
|
logger.info(f"Task {task_id} is already in execution or completed")
|
||||||
return
|
# Return the current task data
|
||||||
|
return await self.redis_cache.get(f"task:{task_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update to "working" state
|
# Update to "working" state - NÃO adicionar ao histórico
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id, TaskState.WORKING, "Processing request"
|
task_id, TaskState.WORKING, "Processing request"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -584,7 +594,7 @@ class A2ATaskManager:
|
|||||||
response = await self.agent_runner.run_agent(
|
response = await self.agent_runner.run_agent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
message=message_text,
|
message=message_text,
|
||||||
session_id=params.sessionId,
|
session_id=params.sessionId, # Usar o sessionId da requisição
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -601,37 +611,38 @@ class A2ATaskManager:
|
|||||||
|
|
||||||
# Build the final agent message
|
# Build the final agent message
|
||||||
if response_text:
|
if response_text:
|
||||||
# Create an artifact for the response
|
# Atualizar o histórico com a mensagem do usuário
|
||||||
artifact = Artifact(
|
await self._update_task_history(task_id, params.message)
|
||||||
name="response",
|
|
||||||
parts=[TextPart(text=response_text)],
|
# Create an artifact for the response in Google A2A format
|
||||||
index=0,
|
artifact = {
|
||||||
lastChunk=True,
|
"parts": [{"type": "text", "text": response_text}],
|
||||||
)
|
"index": 0,
|
||||||
|
}
|
||||||
|
|
||||||
# Add the artifact to the task
|
# Add the artifact to the task
|
||||||
await self._add_task_artifact(task_id, artifact)
|
await self._add_task_artifact(task_id, artifact)
|
||||||
|
|
||||||
# Update the task status to completed
|
# Update the task status to completed (sem adicionar ao histórico)
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id, TaskState.COMPLETED, response_text, final=True
|
task_id, TaskState.COMPLETED, response_text, final=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id,
|
task_id,
|
||||||
TaskState.FAILED,
|
TaskState.FAILED,
|
||||||
"The agent did not return a valid response",
|
"The agent did not return a valid response",
|
||||||
final=True,
|
final=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id,
|
task_id,
|
||||||
TaskState.FAILED,
|
TaskState.FAILED,
|
||||||
"Invalid agent response",
|
"Invalid agent response",
|
||||||
final=True,
|
final=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id,
|
task_id,
|
||||||
TaskState.FAILED,
|
TaskState.FAILED,
|
||||||
"Agent adapter not configured",
|
"Agent adapter not configured",
|
||||||
@ -639,15 +650,78 @@ class A2ATaskManager:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error executing task {task_id}: {str(e)}")
|
logger.error(f"Error executing task {task_id}: {str(e)}")
|
||||||
await self._update_task_status(
|
await self._update_task_status_without_history(
|
||||||
task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
|
task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _update_task_status(
|
# Return the updated task data
|
||||||
|
return await self.redis_cache.get(f"task:{task_id}")
|
||||||
|
|
||||||
|
async def _update_task_history(self, task_id: str, message) -> None:
|
||||||
|
"""
|
||||||
|
Atualiza o histórico da tarefa incluindo apenas mensagens do usuário.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: ID da tarefa
|
||||||
|
message: Mensagem do usuário para adicionar ao histórico
|
||||||
|
"""
|
||||||
|
if not message:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Obter dados da tarefa atual
|
||||||
|
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||||
|
if not task_data:
|
||||||
|
logger.warning(f"Task {task_id} not found for history update")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Garantir que há um campo de histórico
|
||||||
|
if "history" not in task_data:
|
||||||
|
task_data["history"] = []
|
||||||
|
|
||||||
|
# Verificar se a mensagem já existe no histórico para evitar duplicação
|
||||||
|
user_message = (
|
||||||
|
message.model_dump() if hasattr(message, "model_dump") else message
|
||||||
|
)
|
||||||
|
message_exists = False
|
||||||
|
|
||||||
|
for msg in task_data["history"]:
|
||||||
|
if self._compare_messages(msg, user_message):
|
||||||
|
message_exists = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Adicionar mensagem se não existir
|
||||||
|
if not message_exists:
|
||||||
|
task_data["history"].append(user_message)
|
||||||
|
logger.info(f"Added new user message to history for task {task_id}")
|
||||||
|
|
||||||
|
# Salvar tarefa atualizada
|
||||||
|
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||||
|
|
||||||
|
# Método auxiliar para comparar mensagens e evitar duplicação no histórico
|
||||||
|
def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool:
|
||||||
|
"""Compara duas mensagens para verificar se são essencialmente iguais."""
|
||||||
|
if msg1.get("role") != msg2.get("role"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts1 = msg1.get("parts", [])
|
||||||
|
parts2 = msg2.get("parts", [])
|
||||||
|
|
||||||
|
if len(parts1) != len(parts2):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for i in range(len(parts1)):
|
||||||
|
if parts1[i].get("type") != parts2[i].get("type"):
|
||||||
|
return False
|
||||||
|
if parts1[i].get("text") != parts2[i].get("text"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _update_task_status_without_history(
|
||||||
self, task_id: str, state: TaskState, message_text: str, final: bool = False
|
self, task_id: str, state: TaskState, message_text: str, final: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update the status of a task.
|
Update the status of a task without changing the history.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_id: ID of the task to be updated
|
task_id: ID of the task to be updated
|
||||||
@ -666,7 +740,6 @@ class A2ATaskManager:
|
|||||||
agent_message = Message(
|
agent_message = Message(
|
||||||
role="agent",
|
role="agent",
|
||||||
parts=[TextPart(text=message_text)],
|
parts=[TextPart(text=message_text)],
|
||||||
metadata={"timestamp": datetime.now().isoformat()},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
status = TaskStatus(
|
status = TaskStatus(
|
||||||
@ -676,13 +749,6 @@ class A2ATaskManager:
|
|||||||
# Update the status in the task
|
# Update the status in the task
|
||||||
task_data["status"] = status.model_dump(exclude_none=True)
|
task_data["status"] = status.model_dump(exclude_none=True)
|
||||||
|
|
||||||
# Update the history, if it exists
|
|
||||||
if "history" not in task_data:
|
|
||||||
task_data["history"] = []
|
|
||||||
|
|
||||||
# Add the message to the history
|
|
||||||
task_data["history"].append(agent_message.model_dump(exclude_none=True))
|
|
||||||
|
|
||||||
# Store the updated task
|
# Store the updated task
|
||||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||||
|
|
||||||
@ -704,13 +770,13 @@ class A2ATaskManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating task status {task_id}: {str(e)}")
|
logger.error(f"Error updating task status {task_id}: {str(e)}")
|
||||||
|
|
||||||
async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None:
|
async def _add_task_artifact(self, task_id: str, artifact) -> None:
|
||||||
"""
|
"""
|
||||||
Add an artifact to a task and publish the update.
|
Add an artifact to a task and publish the update.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_id: Task ID
|
task_id: Task ID
|
||||||
artifact: Artifact to add
|
artifact: Artifact to add (dict no formato do Google)
|
||||||
"""
|
"""
|
||||||
logger.info(f"Adding artifact to task {task_id}")
|
logger.info(f"Adding artifact to task {task_id}")
|
||||||
|
|
||||||
@ -720,13 +786,21 @@ class A2ATaskManager:
|
|||||||
if "artifacts" not in task_data:
|
if "artifacts" not in task_data:
|
||||||
task_data["artifacts"] = []
|
task_data["artifacts"] = []
|
||||||
|
|
||||||
# Convert artifact to dict
|
# Adicionar o artefato sem substituir os existentes
|
||||||
artifact_dict = artifact.model_dump()
|
task_data["artifacts"].append(artifact)
|
||||||
task_data["artifacts"].append(artifact_dict)
|
|
||||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||||
|
|
||||||
|
# Criar um artefato do tipo Artifact para o evento
|
||||||
|
artifact_obj = Artifact(
|
||||||
|
parts=[
|
||||||
|
TextPart(text=part.get("text", ""))
|
||||||
|
for part in artifact.get("parts", [])
|
||||||
|
],
|
||||||
|
index=artifact.get("index", 0),
|
||||||
|
)
|
||||||
|
|
||||||
# Create artifact update event
|
# Create artifact update event
|
||||||
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact)
|
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj)
|
||||||
|
|
||||||
# Publish event
|
# Publish event
|
||||||
await self._publish_task_update(task_id, event)
|
await self._publish_task_update(task_id, event)
|
||||||
|
@ -189,83 +189,114 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
|||||||
logger.info("Generating automatic API key for new agent")
|
logger.info("Generating automatic API key for new agent")
|
||||||
config["api_key"] = generate_api_key()
|
config["api_key"] = generate_api_key()
|
||||||
|
|
||||||
if isinstance(config, dict):
|
# Preservar todos os campos originais
|
||||||
# Process MCP servers
|
processed_config = {}
|
||||||
if "mcp_servers" in config:
|
processed_config["api_key"] = config.get("api_key", "")
|
||||||
if config["mcp_servers"] is not None:
|
|
||||||
processed_servers = []
|
|
||||||
for server in config["mcp_servers"]:
|
|
||||||
# Convert server id to UUID if it's a string
|
|
||||||
server_id = server["id"]
|
|
||||||
if isinstance(server_id, str):
|
|
||||||
server_id = uuid.UUID(server_id)
|
|
||||||
|
|
||||||
# Search for MCP server in the database
|
# Copiar campos originais
|
||||||
mcp_server = get_mcp_server(db, server_id)
|
if "tools" in config:
|
||||||
if not mcp_server:
|
processed_config["tools"] = config["tools"]
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"MCP server not found: {server['id']}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if all required environment variables are provided
|
if "custom_tools" in config:
|
||||||
for env_key, env_value in mcp_server.environments.items():
|
processed_config["custom_tools"] = config["custom_tools"]
|
||||||
if env_key not in server.get("envs", {}):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the processed server with its tools
|
if "sub_agents" in config:
|
||||||
processed_servers.append(
|
processed_config["sub_agents"] = config["sub_agents"]
|
||||||
{
|
|
||||||
"id": str(server["id"]),
|
if "custom_mcp_servers" in config:
|
||||||
"envs": server["envs"],
|
processed_config["custom_mcp_servers"] = config["custom_mcp_servers"]
|
||||||
"tools": server["tools"],
|
|
||||||
}
|
# Preservar outros campos não processados especificamente
|
||||||
|
for key, value in config.items():
|
||||||
|
if key not in [
|
||||||
|
"api_key",
|
||||||
|
"tools",
|
||||||
|
"custom_tools",
|
||||||
|
"sub_agents",
|
||||||
|
"custom_mcp_servers",
|
||||||
|
"mcp_servers",
|
||||||
|
]:
|
||||||
|
processed_config[key] = value
|
||||||
|
|
||||||
|
# Processar apenas campos que precisam de processamento
|
||||||
|
# Process MCP servers
|
||||||
|
if "mcp_servers" in config and config["mcp_servers"] is not None:
|
||||||
|
processed_servers = []
|
||||||
|
for server in config["mcp_servers"]:
|
||||||
|
# Convert server id to UUID if it's a string
|
||||||
|
server_id = server["id"]
|
||||||
|
if isinstance(server_id, str):
|
||||||
|
server_id = uuid.UUID(server_id)
|
||||||
|
|
||||||
|
# Search for MCP server in the database
|
||||||
|
mcp_server = get_mcp_server(db, server_id)
|
||||||
|
if not mcp_server:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"MCP server not found: {server['id']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if all required environment variables are provided
|
||||||
|
for env_key, env_value in mcp_server.environments.items():
|
||||||
|
if env_key not in server.get("envs", {}):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
config["mcp_servers"] = processed_servers
|
# Add the processed server
|
||||||
else:
|
processed_servers.append(
|
||||||
config["mcp_servers"] = []
|
{
|
||||||
|
"id": str(server["id"]),
|
||||||
|
"envs": server["envs"],
|
||||||
|
"tools": server["tools"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Process custom MCP servers
|
processed_config["mcp_servers"] = processed_servers
|
||||||
if "custom_mcp_servers" in config:
|
elif "mcp_servers" in config:
|
||||||
if config["custom_mcp_servers"] is not None:
|
processed_config["mcp_servers"] = config["mcp_servers"]
|
||||||
processed_custom_servers = []
|
|
||||||
for server in config["custom_mcp_servers"]:
|
|
||||||
# Validate URL format
|
|
||||||
if not server.get("url"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="URL is required for custom MCP servers",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the custom server
|
# Process custom MCP servers
|
||||||
processed_custom_servers.append(
|
if "custom_mcp_servers" in config and config["custom_mcp_servers"] is not None:
|
||||||
{"url": server["url"], "headers": server.get("headers", {})}
|
processed_custom_servers = []
|
||||||
)
|
for server in config["custom_mcp_servers"]:
|
||||||
|
# Validate URL format
|
||||||
|
if not server.get("url"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="URL is required for custom MCP servers",
|
||||||
|
)
|
||||||
|
|
||||||
config["custom_mcp_servers"] = processed_custom_servers
|
# Add the custom server
|
||||||
else:
|
processed_custom_servers.append(
|
||||||
config["custom_mcp_servers"] = []
|
{"url": server["url"], "headers": server.get("headers", {})}
|
||||||
|
)
|
||||||
|
|
||||||
# Process sub-agents
|
processed_config["custom_mcp_servers"] = processed_custom_servers
|
||||||
if "sub_agents" in config:
|
|
||||||
if config["sub_agents"] is not None:
|
|
||||||
config["sub_agents"] = [
|
|
||||||
str(agent_id) for agent_id in config["sub_agents"]
|
|
||||||
]
|
|
||||||
|
|
||||||
# Process tools
|
# Process sub-agents
|
||||||
if "tools" in config:
|
if "sub_agents" in config and config["sub_agents"] is not None:
|
||||||
if config["tools"] is not None:
|
processed_config["sub_agents"] = [
|
||||||
config["tools"] = [
|
str(agent_id) for agent_id in config["sub_agents"]
|
||||||
{"id": str(tool["id"]), "envs": tool["envs"]}
|
]
|
||||||
for tool in config["tools"]
|
|
||||||
]
|
|
||||||
|
|
||||||
agent.config = config
|
# Process tools
|
||||||
|
if "tools" in config and config["tools"] is not None:
|
||||||
|
processed_tools = []
|
||||||
|
for tool in config["tools"]:
|
||||||
|
# Convert tool id to string
|
||||||
|
tool_id = tool["id"]
|
||||||
|
|
||||||
|
# Validar envs para garantir que não é None
|
||||||
|
envs = tool.get("envs", {})
|
||||||
|
if envs is None:
|
||||||
|
envs = {}
|
||||||
|
|
||||||
|
processed_tools.append({"id": str(tool_id), "envs": envs})
|
||||||
|
processed_config["tools"] = processed_tools
|
||||||
|
|
||||||
|
agent.config = processed_config
|
||||||
|
|
||||||
# Ensure all config objects are serializable (convert UUIDs to strings)
|
# Ensure all config objects are serializable (convert UUIDs to strings)
|
||||||
if agent.config is not None:
|
if agent.config is not None:
|
||||||
@ -398,82 +429,117 @@ async def update_agent(
|
|||||||
if "config" in agent_data:
|
if "config" in agent_data:
|
||||||
config = agent_data["config"]
|
config = agent_data["config"]
|
||||||
|
|
||||||
|
# Preservar todos os campos originais
|
||||||
|
processed_config = {}
|
||||||
|
processed_config["api_key"] = config.get("api_key", "")
|
||||||
|
|
||||||
|
# Copiar campos originais
|
||||||
|
if "tools" in config:
|
||||||
|
processed_config["tools"] = config["tools"]
|
||||||
|
|
||||||
|
if "custom_tools" in config:
|
||||||
|
processed_config["custom_tools"] = config["custom_tools"]
|
||||||
|
|
||||||
|
if "sub_agents" in config:
|
||||||
|
processed_config["sub_agents"] = config["sub_agents"]
|
||||||
|
|
||||||
|
if "custom_mcp_servers" in config:
|
||||||
|
processed_config["custom_mcp_servers"] = config["custom_mcp_servers"]
|
||||||
|
|
||||||
|
# Preservar outros campos não processados especificamente
|
||||||
|
for key, value in config.items():
|
||||||
|
if key not in [
|
||||||
|
"api_key",
|
||||||
|
"tools",
|
||||||
|
"custom_tools",
|
||||||
|
"sub_agents",
|
||||||
|
"custom_mcp_servers",
|
||||||
|
"mcp_servers",
|
||||||
|
]:
|
||||||
|
processed_config[key] = value
|
||||||
|
|
||||||
|
# Processar apenas campos que precisam de processamento
|
||||||
# Process MCP servers
|
# Process MCP servers
|
||||||
if "mcp_servers" in config:
|
if "mcp_servers" in config and config["mcp_servers"] is not None:
|
||||||
if config["mcp_servers"] is not None:
|
processed_servers = []
|
||||||
processed_servers = []
|
for server in config["mcp_servers"]:
|
||||||
for server in config["mcp_servers"]:
|
# Convert server id to UUID if it's a string
|
||||||
# Convert server id to UUID if it's a string
|
server_id = server["id"]
|
||||||
server_id = server["id"]
|
if isinstance(server_id, str):
|
||||||
if isinstance(server_id, str):
|
server_id = uuid.UUID(server_id)
|
||||||
server_id = uuid.UUID(server_id)
|
|
||||||
|
|
||||||
# Search for MCP server in the database
|
# Search for MCP server in the database
|
||||||
mcp_server = get_mcp_server(db, server_id)
|
mcp_server = get_mcp_server(db, server_id)
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"MCP server not found: {server['id']}",
|
detail=f"MCP server not found: {server['id']}",
|
||||||
)
|
|
||||||
|
|
||||||
# Check if all required environment variables are provided
|
|
||||||
for env_key, env_value in mcp_server.environments.items():
|
|
||||||
if env_key not in server.get("envs", {}):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the processed server
|
|
||||||
processed_servers.append(
|
|
||||||
{
|
|
||||||
"id": str(server["id"]),
|
|
||||||
"envs": server["envs"],
|
|
||||||
"tools": server["tools"],
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
config["mcp_servers"] = processed_servers
|
# Check if all required environment variables are provided
|
||||||
else:
|
for env_key, env_value in mcp_server.environments.items():
|
||||||
config["mcp_servers"] = []
|
if env_key not in server.get("envs", {}):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the processed server
|
||||||
|
processed_servers.append(
|
||||||
|
{
|
||||||
|
"id": str(server["id"]),
|
||||||
|
"envs": server["envs"],
|
||||||
|
"tools": server["tools"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_config["mcp_servers"] = processed_servers
|
||||||
|
elif "mcp_servers" in config:
|
||||||
|
processed_config["mcp_servers"] = config["mcp_servers"]
|
||||||
|
|
||||||
# Process custom MCP servers
|
# Process custom MCP servers
|
||||||
if "custom_mcp_servers" in config:
|
if (
|
||||||
if config["custom_mcp_servers"] is not None:
|
"custom_mcp_servers" in config
|
||||||
processed_custom_servers = []
|
and config["custom_mcp_servers"] is not None
|
||||||
for server in config["custom_mcp_servers"]:
|
):
|
||||||
# Validate URL format
|
processed_custom_servers = []
|
||||||
if not server.get("url"):
|
for server in config["custom_mcp_servers"]:
|
||||||
raise HTTPException(
|
# Validate URL format
|
||||||
status_code=400,
|
if not server.get("url"):
|
||||||
detail="URL is required for custom MCP servers",
|
raise HTTPException(
|
||||||
)
|
status_code=400,
|
||||||
|
detail="URL is required for custom MCP servers",
|
||||||
# Add the custom server
|
|
||||||
processed_custom_servers.append(
|
|
||||||
{"url": server["url"], "headers": server.get("headers", {})}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
config["custom_mcp_servers"] = processed_custom_servers
|
# Add the custom server
|
||||||
else:
|
processed_custom_servers.append(
|
||||||
config["custom_mcp_servers"] = []
|
{"url": server["url"], "headers": server.get("headers", {})}
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_config["custom_mcp_servers"] = processed_custom_servers
|
||||||
|
|
||||||
# Process sub-agents
|
# Process sub-agents
|
||||||
if "sub_agents" in config:
|
if "sub_agents" in config and config["sub_agents"] is not None:
|
||||||
if config["sub_agents"] is not None:
|
processed_config["sub_agents"] = [
|
||||||
config["sub_agents"] = [
|
str(agent_id) for agent_id in config["sub_agents"]
|
||||||
str(agent_id) for agent_id in config["sub_agents"]
|
]
|
||||||
]
|
|
||||||
|
|
||||||
# Process tools
|
# Process tools
|
||||||
if "tools" in config:
|
if "tools" in config and config["tools"] is not None:
|
||||||
if config["tools"] is not None:
|
processed_tools = []
|
||||||
config["tools"] = [
|
for tool in config["tools"]:
|
||||||
{"id": str(tool["id"]), "envs": tool["envs"]}
|
# Convert tool id to string
|
||||||
for tool in config["tools"]
|
tool_id = tool["id"]
|
||||||
]
|
|
||||||
|
|
||||||
agent_data["config"] = config
|
# Validar envs para garantir que não é None
|
||||||
|
envs = tool.get("envs", {})
|
||||||
|
if envs is None:
|
||||||
|
envs = {}
|
||||||
|
|
||||||
|
processed_tools.append({"id": str(tool_id), "envs": envs})
|
||||||
|
processed_config["tools"] = processed_tools
|
||||||
|
|
||||||
|
agent_data["config"] = processed_config
|
||||||
|
|
||||||
# Ensure all config objects are serializable (convert UUIDs to strings)
|
# Ensure all config objects are serializable (convert UUIDs to strings)
|
||||||
if "config" in agent_data and agent_data["config"] is not None:
|
if "config" in agent_data and agent_data["config"] is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user