refactor(a2a): enhance A2A request processing and task management with session ID handling

Davidson Gomes 2025-05-05 19:15:14 -03:00
parent 8f1fef71a5
commit a0f984ae21
4 changed files with 352 additions and 175 deletions
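The session ID handling introduced here applies to JSON-RPC calls against the agent's A2A endpoint. As a rough illustration of the request shape the handler below expects (the base URL and the x-api-key header name are assumptions, and the exact params schema may require additional fields such as agentId):

import uuid

import httpx

agent_id = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID
payload = {
    "jsonrpc": "2.0",
    "id": str(uuid.uuid4()),  # echoed back in every response after this change
    "method": "tasks/send",
    "params": {
        "id": str(uuid.uuid4()),  # task id
        "sessionId": str(uuid.uuid4()),  # reuse across turns to keep one conversation
        "message": {
            "role": "user",
            "parts": [{"type": "text", "text": "Hello, agent"}],
        },
    },
}
response = httpx.post(
    f"http://localhost:8000/{agent_id}/rpc",  # assumed base URL; the route comes from the diff
    json=payload,
    headers={"x-api-key": "your-api-key"},  # assumed header name for the agent API key
)
print(response.json())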

View File

@ -124,7 +124,7 @@ def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"):
return task_manager
@router.post("/{agent_id}/rpc")
@router.post("/{agent_id}")
async def process_a2a_request(
agent_id: uuid.UUID,
request: Request,
@ -156,6 +156,13 @@ async def process_a2a_request(
try:
body = await request.json()
method = body.get("method", "unknown")
request_id = body.get("id") # Extract request ID to ensure it's preserved
# Extract the sessionId from params when the method is tasks/send
session_id = None
if method == "tasks/send" and body.get("params"):
session_id = body.get("params", {}).get("sessionId")
logger.info(f"Extracted sessionId from request: {session_id}")
is_query_request = method in [
"tasks/get",
@ -189,7 +196,7 @@ async def process_a2a_request(
status_code=404,
content={
"jsonrpc": "2.0",
"id": None,
"id": request_id, # Use the extracted request ID
"error": {"code": 404, "message": "Agent not found", "data": None},
},
)
@ -203,7 +210,7 @@ async def process_a2a_request(
status_code=401,
content={
"jsonrpc": "2.0",
"id": None,
"id": request_id, # Use the extracted request ID
"error": {"code": 401, "message": "Invalid API key", "data": None},
},
)
@ -226,7 +233,7 @@ async def process_a2a_request(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id"),
"id": request_id, # Use the extracted request ID
"error": {
"code": -32600,
"message": "Invalid Request: jsonrpc must be '2.0'",
@ -241,7 +248,7 @@ async def process_a2a_request(
status_code=400,
content={
"jsonrpc": "2.0",
"id": body.get("id"),
"id": request_id, # Use the extracted request ID
"error": {
"code": -32600,
"message": "Invalid Request: method is required",
@ -250,15 +257,24 @@ async def process_a2a_request(
},
)
# Process the request normally
return await a2a_server.process_request(request, agent_id=str(agent_id), db=db)
except Exception as e:
logger.error(f"Error processing A2A request: {str(e)}", exc_info=True)
# Try to extract request ID from the body, if available
request_id = None
try:
body = await request.json()
request_id = body.get("id")
except Exception:
pass
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"id": None,
"id": request_id, # Use the extracted request ID or None
"error": {
"code": -32603,
"message": "Internal server error",

View File

@ -88,6 +88,7 @@ class AgentRunnerAdapter:
try:
# Use the existing agent runner function
# Use the provided session_id, or generate a new one
session_id = session_id or str(uuid.uuid4())
task_id = task_id or str(uuid.uuid4())
@ -105,12 +106,28 @@ class AgentRunnerAdapter:
session_id=session_id,
)
# Format the response to include both the A2A-compatible message
# and the artifact format to match the Google A2A implementation
# Note: the artifact format is simplified for compatibility with Google A2A
message_obj = {
"role": "agent",
"parts": [{"type": "text", "text": response_text}],
}
# Artifact format compatible with Google A2A
artifact_obj = {
"parts": [{"type": "text", "text": response_text}],
"index": 0,
}
return {
"status": "success",
"content": response_text,
"session_id": session_id,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
"message": message_obj,
"artifact": artifact_obj,
}
except Exception as e:
@ -121,6 +138,10 @@ class AgentRunnerAdapter:
"session_id": session_id,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
"message": {
"role": "agent",
"parts": [{"type": "text", "text": f"Error: {str(e)}"}],
},
}
async def cancel_task(self, task_id: str) -> bool:
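The adapter's success payload now carries both an A2A-compatible message and a Google-A2A-style artifact. A hypothetical caller (not shown in this diff) would consume it roughly as sketched below, assuming agent_runner is an AgentRunnerAdapter instance:

async def call_adapter(agent_runner, agent_id, session_id, task_id):
    # agent_runner is assumed to be an AgentRunnerAdapter instance
    result = await agent_runner.run_agent(
        agent_id=agent_id,
        message="What is the weather today?",
        session_id=session_id,  # reuse to keep conversation context
        task_id=task_id,
    )
    if result.get("status") == "success":
        # A2A-compatible message plus Google-style artifact ({"parts": [...], "index": 0})
        return result["message"], result["artifact"]
    # On error, "message" wraps the error text as an agent message
    return result["message"], None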

View File

@ -9,6 +9,7 @@ import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, Union, AsyncIterable
import uuid
from src.schemas.a2a.exceptions import (
TaskNotFoundError,
@ -263,8 +264,9 @@ class A2ATaskManager:
f"task_notification:{task_id}", params.pushNotification.model_dump()
)
# Start task execution in background
asyncio.create_task(self._execute_task(task_data, params))
# Execute task SYNCHRONOUSLY instead of in background
# This is the key change for A2A compatibility
task_data = await self._execute_task(task_data, params)
# Convert to Task object and return
task = Task.model_validate(task_data)
@ -523,7 +525,8 @@ class A2ATaskManager:
# Create task with initial status
task_data = {
"id": params.id,
"sessionId": params.sessionId,
"sessionId": params.sessionId
or str(uuid.uuid4()), # Preserve the sessionId when provided
"status": {
"state": TaskState.SUBMITTED,
"timestamp": datetime.now().isoformat(),
@ -531,7 +534,7 @@ class A2ATaskManager:
"error": None,
},
"artifacts": [],
"history": [params.message.model_dump()],
"history": [params.message.model_dump()], # Apenas mensagem do usuário
"metadata": params.metadata or {},
}
@ -540,7 +543,9 @@ class A2ATaskManager:
return task_data
async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None:
async def _execute_task(
self, task: Dict[str, Any], params: TaskSendParams
) -> Dict[str, Any]:
"""
Execute a task using the agent adapter.
@ -550,6 +555,9 @@ class A2ATaskManager:
Args:
task: Task data to be executed
params: Send task parameters
Returns:
Updated task data with completed status and response
"""
task_id = task["id"]
agent_id = params.agentId
@ -562,20 +570,22 @@ class A2ATaskManager:
message_text += part.text
if not message_text:
await self._update_task_status(
await self._update_task_status_without_history(
task_id, TaskState.FAILED, "Message does not contain text", final=True
)
return
# Return the updated task data
return await self.redis_cache.get(f"task:{task_id}")
# Check if it is an ongoing execution
task_status = task.get("status", {})
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
logger.info(f"Task {task_id} is already in execution or completed")
return
# Return the current task data
return await self.redis_cache.get(f"task:{task_id}")
try:
# Update to "working" state
await self._update_task_status(
# Update to "working" state - NÃO adicionar ao histórico
await self._update_task_status_without_history(
task_id, TaskState.WORKING, "Processing request"
)
@ -584,7 +594,7 @@ class A2ATaskManager:
response = await self.agent_runner.run_agent(
agent_id=agent_id,
message=message_text,
session_id=params.sessionId,
session_id=params.sessionId, # Use the sessionId from the request
task_id=task_id,
)
@ -601,37 +611,38 @@ class A2ATaskManager:
# Build the final agent message
if response_text:
# Create an artifact for the response
artifact = Artifact(
name="response",
parts=[TextPart(text=response_text)],
index=0,
lastChunk=True,
)
# Update the history with the user message
await self._update_task_history(task_id, params.message)
# Create an artifact for the response in Google A2A format
artifact = {
"parts": [{"type": "text", "text": response_text}],
"index": 0,
}
# Add the artifact to the task
await self._add_task_artifact(task_id, artifact)
# Update the task status to completed
await self._update_task_status(
# Update the task status to completed (without adding to the history)
await self._update_task_status_without_history(
task_id, TaskState.COMPLETED, response_text, final=True
)
else:
await self._update_task_status(
await self._update_task_status_without_history(
task_id,
TaskState.FAILED,
"The agent did not return a valid response",
final=True,
)
else:
await self._update_task_status(
await self._update_task_status_without_history(
task_id,
TaskState.FAILED,
"Invalid agent response",
final=True,
)
else:
await self._update_task_status(
await self._update_task_status_without_history(
task_id,
TaskState.FAILED,
"Agent adapter not configured",
@ -639,15 +650,78 @@ class A2ATaskManager:
)
except Exception as e:
logger.error(f"Error executing task {task_id}: {str(e)}")
await self._update_task_status(
await self._update_task_status_without_history(
task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
)
async def _update_task_status(
# Return the updated task data
return await self.redis_cache.get(f"task:{task_id}")
async def _update_task_history(self, task_id: str, message) -> None:
"""
Update the task history, including only user messages.
Args:
task_id: Task ID
message: User message to add to the history
"""
if not message:
return
# Get the current task data
task_data = await self.redis_cache.get(f"task:{task_id}")
if not task_data:
logger.warning(f"Task {task_id} not found for history update")
return
# Ensure there is a history field
if "history" not in task_data:
task_data["history"] = []
# Check whether the message already exists in the history to avoid duplication
user_message = (
message.model_dump() if hasattr(message, "model_dump") else message
)
message_exists = False
for msg in task_data["history"]:
if self._compare_messages(msg, user_message):
message_exists = True
break
# Add the message if it does not already exist
if not message_exists:
task_data["history"].append(user_message)
logger.info(f"Added new user message to history for task {task_id}")
# Save the updated task
await self.redis_cache.set(f"task:{task_id}", task_data)
# Helper method to compare messages and avoid duplicates in the history
def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool:
"""Compara duas mensagens para verificar se são essencialmente iguais."""
if msg1.get("role") != msg2.get("role"):
return False
parts1 = msg1.get("parts", [])
parts2 = msg2.get("parts", [])
if len(parts1) != len(parts2):
return False
for i in range(len(parts1)):
if parts1[i].get("type") != parts2[i].get("type"):
return False
if parts1[i].get("text") != parts2[i].get("text"):
return False
return True
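As a quick illustration of the deduplication check, assuming task_manager is an A2ATaskManager instance: only role and each part's type/text are compared, so resending the same text is not appended to the history twice.

msg_a = {"role": "user", "parts": [{"type": "text", "text": "Hi"}]}
msg_b = {
    "role": "user",
    "parts": [{"type": "text", "text": "Hi"}],
    "metadata": {"timestamp": "2025-05-05T19:15:14"},  # extra fields are ignored
}
assert task_manager._compare_messages(msg_a, msg_b) is True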
async def _update_task_status_without_history(
self, task_id: str, state: TaskState, message_text: str, final: bool = False
) -> None:
"""
Update the status of a task.
Update the status of a task without changing the history.
Args:
task_id: ID of the task to be updated
@ -666,7 +740,6 @@ class A2ATaskManager:
agent_message = Message(
role="agent",
parts=[TextPart(text=message_text)],
metadata={"timestamp": datetime.now().isoformat()},
)
status = TaskStatus(
@ -676,13 +749,6 @@ class A2ATaskManager:
# Update the status in the task
task_data["status"] = status.model_dump(exclude_none=True)
# Update the history, if it exists
if "history" not in task_data:
task_data["history"] = []
# Add the message to the history
task_data["history"].append(agent_message.model_dump(exclude_none=True))
# Store the updated task
await self.redis_cache.set(f"task:{task_id}", task_data)
@ -704,13 +770,13 @@ class A2ATaskManager:
except Exception as e:
logger.error(f"Error updating task status {task_id}: {str(e)}")
async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None:
async def _add_task_artifact(self, task_id: str, artifact) -> None:
"""
Add an artifact to a task and publish the update.
Args:
task_id: Task ID
artifact: Artifact to add
artifact: Artifact to add (dict in the Google A2A format)
"""
logger.info(f"Adding artifact to task {task_id}")
@ -720,13 +786,21 @@ class A2ATaskManager:
if "artifacts" not in task_data:
task_data["artifacts"] = []
# Convert artifact to dict
artifact_dict = artifact.model_dump()
task_data["artifacts"].append(artifact_dict)
# Append the artifact without replacing the existing ones
task_data["artifacts"].append(artifact)
await self.redis_cache.set(f"task:{task_id}", task_data)
# Create an Artifact object for the event
artifact_obj = Artifact(
parts=[
TextPart(text=part.get("text", ""))
for part in artifact.get("parts", [])
],
index=artifact.get("index", 0),
)
# Create artifact update event
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact)
event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj)
# Publish event
await self._publish_task_update(task_id, event)
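Taken together, these changes mean the task returned synchronously by tasks/send already carries the final state: history holds only user messages, artifacts use the Google A2A dict shape, and status updates no longer push agent messages into the history. A rough sketch of the resulting task data (values are illustrative; the exact status fields follow the TaskStatus schema):

example_task = {
    "id": "task-uuid",
    "sessionId": "session-uuid",  # preserved from the request, or generated if absent
    "status": {
        "state": "completed",
        "timestamp": "2025-05-05T19:15:14",
        "message": {"role": "agent", "parts": [{"type": "text", "text": "Final answer"}]},
        "error": None,
    },
    "history": [  # user messages only
        {"role": "user", "parts": [{"type": "text", "text": "Original question"}]},
    ],
    "artifacts": [  # Google A2A dict format
        {"parts": [{"type": "text", "text": "Final answer"}], "index": 0},
    ],
    "metadata": {},
}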

View File

@ -189,83 +189,114 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent:
logger.info("Generating automatic API key for new agent")
config["api_key"] = generate_api_key()
if isinstance(config, dict):
# Process MCP servers
if "mcp_servers" in config:
if config["mcp_servers"] is not None:
processed_servers = []
for server in config["mcp_servers"]:
# Convert server id to UUID if it's a string
server_id = server["id"]
if isinstance(server_id, str):
server_id = uuid.UUID(server_id)
# Preserve all original fields
processed_config = {}
processed_config["api_key"] = config.get("api_key", "")
# Search for MCP server in the database
mcp_server = get_mcp_server(db, server_id)
if not mcp_server:
raise HTTPException(
status_code=400,
detail=f"MCP server not found: {server['id']}",
)
# Copy original fields
if "tools" in config:
processed_config["tools"] = config["tools"]
# Check if all required environment variables are provided
for env_key, env_value in mcp_server.environments.items():
if env_key not in server.get("envs", {}):
raise HTTPException(
status_code=400,
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
)
if "custom_tools" in config:
processed_config["custom_tools"] = config["custom_tools"]
# Add the processed server with its tools
processed_servers.append(
{
"id": str(server["id"]),
"envs": server["envs"],
"tools": server["tools"],
}
if "sub_agents" in config:
processed_config["sub_agents"] = config["sub_agents"]
if "custom_mcp_servers" in config:
processed_config["custom_mcp_servers"] = config["custom_mcp_servers"]
# Preserve other fields that are not specifically processed
for key, value in config.items():
if key not in [
"api_key",
"tools",
"custom_tools",
"sub_agents",
"custom_mcp_servers",
"mcp_servers",
]:
processed_config[key] = value
# Process only the fields that need processing
# Process MCP servers
if "mcp_servers" in config and config["mcp_servers"] is not None:
processed_servers = []
for server in config["mcp_servers"]:
# Convert server id to UUID if it's a string
server_id = server["id"]
if isinstance(server_id, str):
server_id = uuid.UUID(server_id)
# Search for MCP server in the database
mcp_server = get_mcp_server(db, server_id)
if not mcp_server:
raise HTTPException(
status_code=400,
detail=f"MCP server not found: {server['id']}",
)
# Check if all required environment variables are provided
for env_key, env_value in mcp_server.environments.items():
if env_key not in server.get("envs", {}):
raise HTTPException(
status_code=400,
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
)
config["mcp_servers"] = processed_servers
else:
config["mcp_servers"] = []
# Add the processed server
processed_servers.append(
{
"id": str(server["id"]),
"envs": server["envs"],
"tools": server["tools"],
}
)
# Process custom MCP servers
if "custom_mcp_servers" in config:
if config["custom_mcp_servers"] is not None:
processed_custom_servers = []
for server in config["custom_mcp_servers"]:
# Validate URL format
if not server.get("url"):
raise HTTPException(
status_code=400,
detail="URL is required for custom MCP servers",
)
processed_config["mcp_servers"] = processed_servers
elif "mcp_servers" in config:
processed_config["mcp_servers"] = config["mcp_servers"]
# Add the custom server
processed_custom_servers.append(
{"url": server["url"], "headers": server.get("headers", {})}
)
# Process custom MCP servers
if "custom_mcp_servers" in config and config["custom_mcp_servers"] is not None:
processed_custom_servers = []
for server in config["custom_mcp_servers"]:
# Validate URL format
if not server.get("url"):
raise HTTPException(
status_code=400,
detail="URL is required for custom MCP servers",
)
config["custom_mcp_servers"] = processed_custom_servers
else:
config["custom_mcp_servers"] = []
# Add the custom server
processed_custom_servers.append(
{"url": server["url"], "headers": server.get("headers", {})}
)
# Process sub-agents
if "sub_agents" in config:
if config["sub_agents"] is not None:
config["sub_agents"] = [
str(agent_id) for agent_id in config["sub_agents"]
]
processed_config["custom_mcp_servers"] = processed_custom_servers
# Process tools
if "tools" in config:
if config["tools"] is not None:
config["tools"] = [
{"id": str(tool["id"]), "envs": tool["envs"]}
for tool in config["tools"]
]
# Process sub-agents
if "sub_agents" in config and config["sub_agents"] is not None:
processed_config["sub_agents"] = [
str(agent_id) for agent_id in config["sub_agents"]
]
agent.config = config
# Process tools
if "tools" in config and config["tools"] is not None:
processed_tools = []
for tool in config["tools"]:
# Convert tool id to string
tool_id = tool["id"]
# Validate envs to ensure it is not None
envs = tool.get("envs", {})
if envs is None:
envs = {}
processed_tools.append({"id": str(tool_id), "envs": envs})
processed_config["tools"] = processed_tools
agent.config = processed_config
# Ensure all config objects are serializable (convert UUIDs to strings)
if agent.config is not None:
@ -398,82 +429,117 @@ async def update_agent(
if "config" in agent_data:
config = agent_data["config"]
# Preserve all original fields
processed_config = {}
processed_config["api_key"] = config.get("api_key", "")
# Copy original fields
if "tools" in config:
processed_config["tools"] = config["tools"]
if "custom_tools" in config:
processed_config["custom_tools"] = config["custom_tools"]
if "sub_agents" in config:
processed_config["sub_agents"] = config["sub_agents"]
if "custom_mcp_servers" in config:
processed_config["custom_mcp_servers"] = config["custom_mcp_servers"]
# Preserve other fields that are not specifically processed
for key, value in config.items():
if key not in [
"api_key",
"tools",
"custom_tools",
"sub_agents",
"custom_mcp_servers",
"mcp_servers",
]:
processed_config[key] = value
# Process only the fields that need processing
# Process MCP servers
if "mcp_servers" in config:
if config["mcp_servers"] is not None:
processed_servers = []
for server in config["mcp_servers"]:
# Convert server id to UUID if it's a string
server_id = server["id"]
if isinstance(server_id, str):
server_id = uuid.UUID(server_id)
if "mcp_servers" in config and config["mcp_servers"] is not None:
processed_servers = []
for server in config["mcp_servers"]:
# Convert server id to UUID if it's a string
server_id = server["id"]
if isinstance(server_id, str):
server_id = uuid.UUID(server_id)
# Search for MCP server in the database
mcp_server = get_mcp_server(db, server_id)
if not mcp_server:
raise HTTPException(
status_code=400,
detail=f"MCP server not found: {server['id']}",
)
# Check if all required environment variables are provided
for env_key, env_value in mcp_server.environments.items():
if env_key not in server.get("envs", {}):
raise HTTPException(
status_code=400,
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
)
# Add the processed server
processed_servers.append(
{
"id": str(server["id"]),
"envs": server["envs"],
"tools": server["tools"],
}
# Search for MCP server in the database
mcp_server = get_mcp_server(db, server_id)
if not mcp_server:
raise HTTPException(
status_code=400,
detail=f"MCP server not found: {server['id']}",
)
config["mcp_servers"] = processed_servers
else:
config["mcp_servers"] = []
# Check if all required environment variables are provided
for env_key, env_value in mcp_server.environments.items():
if env_key not in server.get("envs", {}):
raise HTTPException(
status_code=400,
detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}",
)
# Add the processed server
processed_servers.append(
{
"id": str(server["id"]),
"envs": server["envs"],
"tools": server["tools"],
}
)
processed_config["mcp_servers"] = processed_servers
elif "mcp_servers" in config:
processed_config["mcp_servers"] = config["mcp_servers"]
# Process custom MCP servers
if "custom_mcp_servers" in config:
if config["custom_mcp_servers"] is not None:
processed_custom_servers = []
for server in config["custom_mcp_servers"]:
# Validate URL format
if not server.get("url"):
raise HTTPException(
status_code=400,
detail="URL is required for custom MCP servers",
)
# Add the custom server
processed_custom_servers.append(
{"url": server["url"], "headers": server.get("headers", {})}
if (
"custom_mcp_servers" in config
and config["custom_mcp_servers"] is not None
):
processed_custom_servers = []
for server in config["custom_mcp_servers"]:
# Validate URL format
if not server.get("url"):
raise HTTPException(
status_code=400,
detail="URL is required for custom MCP servers",
)
config["custom_mcp_servers"] = processed_custom_servers
else:
config["custom_mcp_servers"] = []
# Add the custom server
processed_custom_servers.append(
{"url": server["url"], "headers": server.get("headers", {})}
)
processed_config["custom_mcp_servers"] = processed_custom_servers
# Process sub-agents
if "sub_agents" in config:
if config["sub_agents"] is not None:
config["sub_agents"] = [
str(agent_id) for agent_id in config["sub_agents"]
]
if "sub_agents" in config and config["sub_agents"] is not None:
processed_config["sub_agents"] = [
str(agent_id) for agent_id in config["sub_agents"]
]
# Process tools
if "tools" in config:
if config["tools"] is not None:
config["tools"] = [
{"id": str(tool["id"]), "envs": tool["envs"]}
for tool in config["tools"]
]
if "tools" in config and config["tools"] is not None:
processed_tools = []
for tool in config["tools"]:
# Convert tool id to string
tool_id = tool["id"]
agent_data["config"] = config
# Validate envs to ensure it is not None
envs = tool.get("envs", {})
if envs is None:
envs = {}
processed_tools.append({"id": str(tool_id), "envs": envs})
processed_config["tools"] = processed_tools
agent_data["config"] = processed_config
# Ensure all config objects are serializable (convert UUIDs to strings)
if "config" in agent_data and agent_data["config"] is not None: