From a0f984ae21c4d5d68d94d918e1f2e9aedf8936af Mon Sep 17 00:00:00 2001 From: Davidson Gomes Date: Mon, 5 May 2025 19:15:14 -0300 Subject: [PATCH] refactor(a2a): enhance A2A request processing and task management with session ID handling --- src/api/a2a_routes.py | 28 +- src/services/a2a_integration_service.py | 21 ++ src/services/a2a_task_manager_service.py | 154 ++++++++--- src/services/agent_service.py | 324 ++++++++++++++--------- 4 files changed, 352 insertions(+), 175 deletions(-) diff --git a/src/api/a2a_routes.py b/src/api/a2a_routes.py index bbf17dd0..92413356 100644 --- a/src/api/a2a_routes.py +++ b/src/api/a2a_routes.py @@ -124,7 +124,7 @@ def get_task_manager(agent_id, db=None, reuse=True, operation_type="query"): return task_manager -@router.post("/{agent_id}/rpc") +@router.post("/{agent_id}") async def process_a2a_request( agent_id: uuid.UUID, request: Request, @@ -156,6 +156,13 @@ async def process_a2a_request( try: body = await request.json() method = body.get("method", "unknown") + request_id = body.get("id") # Extract request ID to ensure it's preserved + + # Extrair sessionId do params se for tasks/send + session_id = None + if method == "tasks/send" and body.get("params"): + session_id = body.get("params", {}).get("sessionId") + logger.info(f"Extracted sessionId from request: {session_id}") is_query_request = method in [ "tasks/get", @@ -189,7 +196,7 @@ async def process_a2a_request( status_code=404, content={ "jsonrpc": "2.0", - "id": None, + "id": request_id, # Use the extracted request ID "error": {"code": 404, "message": "Agent not found", "data": None}, }, ) @@ -203,7 +210,7 @@ async def process_a2a_request( status_code=401, content={ "jsonrpc": "2.0", - "id": None, + "id": request_id, # Use the extracted request ID "error": {"code": 401, "message": "Invalid API key", "data": None}, }, ) @@ -226,7 +233,7 @@ async def process_a2a_request( status_code=400, content={ "jsonrpc": "2.0", - "id": body.get("id"), + "id": request_id, # Use the extracted request ID "error": { "code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'", @@ -241,7 +248,7 @@ async def process_a2a_request( status_code=400, content={ "jsonrpc": "2.0", - "id": body.get("id"), + "id": request_id, # Use the extracted request ID "error": { "code": -32600, "message": "Invalid Request: method is required", @@ -250,15 +257,24 @@ async def process_a2a_request( }, ) + # Processar a requisição normalmente return await a2a_server.process_request(request, agent_id=str(agent_id), db=db) except Exception as e: logger.error(f"Error processing A2A request: {str(e)}", exc_info=True) + # Try to extract request ID from the body, if available + request_id = None + try: + body = await request.json() + request_id = body.get("id") + except: + pass + return JSONResponse( status_code=500, content={ "jsonrpc": "2.0", - "id": None, + "id": request_id, # Use the extracted request ID or None "error": { "code": -32603, "message": "Internal server error", diff --git a/src/services/a2a_integration_service.py b/src/services/a2a_integration_service.py index f42d2bcb..17a386b9 100644 --- a/src/services/a2a_integration_service.py +++ b/src/services/a2a_integration_service.py @@ -88,6 +88,7 @@ class AgentRunnerAdapter: try: # Use the existing agent runner function + # Usar o session_id fornecido, ou gerar um novo session_id = session_id or str(uuid.uuid4()) task_id = task_id or str(uuid.uuid4()) @@ -105,12 +106,28 @@ class AgentRunnerAdapter: session_id=session_id, ) + # Format the response to include both the A2A-compatible message + # and the artifact format to match the Google A2A implementation + # Nota: O formato dos artifacts é simplificado para compatibilidade com Google A2A + message_obj = { + "role": "agent", + "parts": [{"type": "text", "text": response_text}], + } + + # Formato de artefato compatível com Google A2A + artifact_obj = { + "parts": [{"type": "text", "text": response_text}], + "index": 0, + } + return { "status": "success", "content": response_text, "session_id": session_id, "task_id": task_id, "timestamp": datetime.now().isoformat(), + "message": message_obj, + "artifact": artifact_obj, } except Exception as e: @@ -121,6 +138,10 @@ class AgentRunnerAdapter: "session_id": session_id, "task_id": task_id, "timestamp": datetime.now().isoformat(), + "message": { + "role": "agent", + "parts": [{"type": "text", "text": f"Error: {str(e)}"}], + }, } async def cancel_task(self, task_id: str) -> bool: diff --git a/src/services/a2a_task_manager_service.py b/src/services/a2a_task_manager_service.py index 6d5b926a..1857949c 100644 --- a/src/services/a2a_task_manager_service.py +++ b/src/services/a2a_task_manager_service.py @@ -9,6 +9,7 @@ import asyncio import logging from datetime import datetime from typing import Any, Dict, Union, AsyncIterable +import uuid from src.schemas.a2a.exceptions import ( TaskNotFoundError, @@ -263,8 +264,9 @@ class A2ATaskManager: f"task_notification:{task_id}", params.pushNotification.model_dump() ) - # Start task execution in background - asyncio.create_task(self._execute_task(task_data, params)) + # Execute task SYNCHRONOUSLY instead of in background + # This is the key change for A2A compatibility + task_data = await self._execute_task(task_data, params) # Convert to Task object and return task = Task.model_validate(task_data) @@ -523,7 +525,8 @@ class A2ATaskManager: # Create task with initial status task_data = { "id": params.id, - "sessionId": params.sessionId, + "sessionId": params.sessionId + or str(uuid.uuid4()), # Preservar sessionId quando fornecido "status": { "state": TaskState.SUBMITTED, "timestamp": datetime.now().isoformat(), @@ -531,7 +534,7 @@ class A2ATaskManager: "error": None, }, "artifacts": [], - "history": [params.message.model_dump()], + "history": [params.message.model_dump()], # Apenas mensagem do usuário "metadata": params.metadata or {}, } @@ -540,7 +543,9 @@ class A2ATaskManager: return task_data - async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None: + async def _execute_task( + self, task: Dict[str, Any], params: TaskSendParams + ) -> Dict[str, Any]: """ Execute a task using the agent adapter. @@ -550,6 +555,9 @@ class A2ATaskManager: Args: task: Task data to be executed params: Send task parameters + + Returns: + Updated task data with completed status and response """ task_id = task["id"] agent_id = params.agentId @@ -562,20 +570,22 @@ class A2ATaskManager: message_text += part.text if not message_text: - await self._update_task_status( + await self._update_task_status_without_history( task_id, TaskState.FAILED, "Message does not contain text", final=True ) - return + # Return the updated task data + return await self.redis_cache.get(f"task:{task_id}") # Check if it is an ongoing execution task_status = task.get("status", {}) if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]: logger.info(f"Task {task_id} is already in execution or completed") - return + # Return the current task data + return await self.redis_cache.get(f"task:{task_id}") try: - # Update to "working" state - await self._update_task_status( + # Update to "working" state - NÃO adicionar ao histórico + await self._update_task_status_without_history( task_id, TaskState.WORKING, "Processing request" ) @@ -584,7 +594,7 @@ class A2ATaskManager: response = await self.agent_runner.run_agent( agent_id=agent_id, message=message_text, - session_id=params.sessionId, + session_id=params.sessionId, # Usar o sessionId da requisição task_id=task_id, ) @@ -601,37 +611,38 @@ class A2ATaskManager: # Build the final agent message if response_text: - # Create an artifact for the response - artifact = Artifact( - name="response", - parts=[TextPart(text=response_text)], - index=0, - lastChunk=True, - ) + # Atualizar o histórico com a mensagem do usuário + await self._update_task_history(task_id, params.message) + + # Create an artifact for the response in Google A2A format + artifact = { + "parts": [{"type": "text", "text": response_text}], + "index": 0, + } # Add the artifact to the task await self._add_task_artifact(task_id, artifact) - # Update the task status to completed - await self._update_task_status( + # Update the task status to completed (sem adicionar ao histórico) + await self._update_task_status_without_history( task_id, TaskState.COMPLETED, response_text, final=True ) else: - await self._update_task_status( + await self._update_task_status_without_history( task_id, TaskState.FAILED, "The agent did not return a valid response", final=True, ) else: - await self._update_task_status( + await self._update_task_status_without_history( task_id, TaskState.FAILED, "Invalid agent response", final=True, ) else: - await self._update_task_status( + await self._update_task_status_without_history( task_id, TaskState.FAILED, "Agent adapter not configured", @@ -639,15 +650,78 @@ class A2ATaskManager: ) except Exception as e: logger.error(f"Error executing task {task_id}: {str(e)}") - await self._update_task_status( + await self._update_task_status_without_history( task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True ) - async def _update_task_status( + # Return the updated task data + return await self.redis_cache.get(f"task:{task_id}") + + async def _update_task_history(self, task_id: str, message) -> None: + """ + Atualiza o histórico da tarefa incluindo apenas mensagens do usuário. + + Args: + task_id: ID da tarefa + message: Mensagem do usuário para adicionar ao histórico + """ + if not message: + return + + # Obter dados da tarefa atual + task_data = await self.redis_cache.get(f"task:{task_id}") + if not task_data: + logger.warning(f"Task {task_id} not found for history update") + return + + # Garantir que há um campo de histórico + if "history" not in task_data: + task_data["history"] = [] + + # Verificar se a mensagem já existe no histórico para evitar duplicação + user_message = ( + message.model_dump() if hasattr(message, "model_dump") else message + ) + message_exists = False + + for msg in task_data["history"]: + if self._compare_messages(msg, user_message): + message_exists = True + break + + # Adicionar mensagem se não existir + if not message_exists: + task_data["history"].append(user_message) + logger.info(f"Added new user message to history for task {task_id}") + + # Salvar tarefa atualizada + await self.redis_cache.set(f"task:{task_id}", task_data) + + # Método auxiliar para comparar mensagens e evitar duplicação no histórico + def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool: + """Compara duas mensagens para verificar se são essencialmente iguais.""" + if msg1.get("role") != msg2.get("role"): + return False + + parts1 = msg1.get("parts", []) + parts2 = msg2.get("parts", []) + + if len(parts1) != len(parts2): + return False + + for i in range(len(parts1)): + if parts1[i].get("type") != parts2[i].get("type"): + return False + if parts1[i].get("text") != parts2[i].get("text"): + return False + + return True + + async def _update_task_status_without_history( self, task_id: str, state: TaskState, message_text: str, final: bool = False ) -> None: """ - Update the status of a task. + Update the status of a task without changing the history. Args: task_id: ID of the task to be updated @@ -666,7 +740,6 @@ class A2ATaskManager: agent_message = Message( role="agent", parts=[TextPart(text=message_text)], - metadata={"timestamp": datetime.now().isoformat()}, ) status = TaskStatus( @@ -676,13 +749,6 @@ class A2ATaskManager: # Update the status in the task task_data["status"] = status.model_dump(exclude_none=True) - # Update the history, if it exists - if "history" not in task_data: - task_data["history"] = [] - - # Add the message to the history - task_data["history"].append(agent_message.model_dump(exclude_none=True)) - # Store the updated task await self.redis_cache.set(f"task:{task_id}", task_data) @@ -704,13 +770,13 @@ class A2ATaskManager: except Exception as e: logger.error(f"Error updating task status {task_id}: {str(e)}") - async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None: + async def _add_task_artifact(self, task_id: str, artifact) -> None: """ Add an artifact to a task and publish the update. Args: task_id: Task ID - artifact: Artifact to add + artifact: Artifact to add (dict no formato do Google) """ logger.info(f"Adding artifact to task {task_id}") @@ -720,13 +786,21 @@ class A2ATaskManager: if "artifacts" not in task_data: task_data["artifacts"] = [] - # Convert artifact to dict - artifact_dict = artifact.model_dump() - task_data["artifacts"].append(artifact_dict) + # Adicionar o artefato sem substituir os existentes + task_data["artifacts"].append(artifact) await self.redis_cache.set(f"task:{task_id}", task_data) + # Criar um artefato do tipo Artifact para o evento + artifact_obj = Artifact( + parts=[ + TextPart(text=part.get("text", "")) + for part in artifact.get("parts", []) + ], + index=artifact.get("index", 0), + ) + # Create artifact update event - event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact) + event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj) # Publish event await self._publish_task_update(task_id, event) diff --git a/src/services/agent_service.py b/src/services/agent_service.py index 5d053eaa..2f5775cd 100644 --- a/src/services/agent_service.py +++ b/src/services/agent_service.py @@ -189,83 +189,114 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent: logger.info("Generating automatic API key for new agent") config["api_key"] = generate_api_key() - if isinstance(config, dict): - # Process MCP servers - if "mcp_servers" in config: - if config["mcp_servers"] is not None: - processed_servers = [] - for server in config["mcp_servers"]: - # Convert server id to UUID if it's a string - server_id = server["id"] - if isinstance(server_id, str): - server_id = uuid.UUID(server_id) + # Preservar todos os campos originais + processed_config = {} + processed_config["api_key"] = config.get("api_key", "") - # Search for MCP server in the database - mcp_server = get_mcp_server(db, server_id) - if not mcp_server: - raise HTTPException( - status_code=400, - detail=f"MCP server not found: {server['id']}", - ) + # Copiar campos originais + if "tools" in config: + processed_config["tools"] = config["tools"] - # Check if all required environment variables are provided - for env_key, env_value in mcp_server.environments.items(): - if env_key not in server.get("envs", {}): - raise HTTPException( - status_code=400, - detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}", - ) + if "custom_tools" in config: + processed_config["custom_tools"] = config["custom_tools"] - # Add the processed server with its tools - processed_servers.append( - { - "id": str(server["id"]), - "envs": server["envs"], - "tools": server["tools"], - } + if "sub_agents" in config: + processed_config["sub_agents"] = config["sub_agents"] + + if "custom_mcp_servers" in config: + processed_config["custom_mcp_servers"] = config["custom_mcp_servers"] + + # Preservar outros campos não processados especificamente + for key, value in config.items(): + if key not in [ + "api_key", + "tools", + "custom_tools", + "sub_agents", + "custom_mcp_servers", + "mcp_servers", + ]: + processed_config[key] = value + + # Processar apenas campos que precisam de processamento + # Process MCP servers + if "mcp_servers" in config and config["mcp_servers"] is not None: + processed_servers = [] + for server in config["mcp_servers"]: + # Convert server id to UUID if it's a string + server_id = server["id"] + if isinstance(server_id, str): + server_id = uuid.UUID(server_id) + + # Search for MCP server in the database + mcp_server = get_mcp_server(db, server_id) + if not mcp_server: + raise HTTPException( + status_code=400, + detail=f"MCP server not found: {server['id']}", + ) + + # Check if all required environment variables are provided + for env_key, env_value in mcp_server.environments.items(): + if env_key not in server.get("envs", {}): + raise HTTPException( + status_code=400, + detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}", ) - config["mcp_servers"] = processed_servers - else: - config["mcp_servers"] = [] + # Add the processed server + processed_servers.append( + { + "id": str(server["id"]), + "envs": server["envs"], + "tools": server["tools"], + } + ) - # Process custom MCP servers - if "custom_mcp_servers" in config: - if config["custom_mcp_servers"] is not None: - processed_custom_servers = [] - for server in config["custom_mcp_servers"]: - # Validate URL format - if not server.get("url"): - raise HTTPException( - status_code=400, - detail="URL is required for custom MCP servers", - ) + processed_config["mcp_servers"] = processed_servers + elif "mcp_servers" in config: + processed_config["mcp_servers"] = config["mcp_servers"] - # Add the custom server - processed_custom_servers.append( - {"url": server["url"], "headers": server.get("headers", {})} - ) + # Process custom MCP servers + if "custom_mcp_servers" in config and config["custom_mcp_servers"] is not None: + processed_custom_servers = [] + for server in config["custom_mcp_servers"]: + # Validate URL format + if not server.get("url"): + raise HTTPException( + status_code=400, + detail="URL is required for custom MCP servers", + ) - config["custom_mcp_servers"] = processed_custom_servers - else: - config["custom_mcp_servers"] = [] + # Add the custom server + processed_custom_servers.append( + {"url": server["url"], "headers": server.get("headers", {})} + ) - # Process sub-agents - if "sub_agents" in config: - if config["sub_agents"] is not None: - config["sub_agents"] = [ - str(agent_id) for agent_id in config["sub_agents"] - ] + processed_config["custom_mcp_servers"] = processed_custom_servers - # Process tools - if "tools" in config: - if config["tools"] is not None: - config["tools"] = [ - {"id": str(tool["id"]), "envs": tool["envs"]} - for tool in config["tools"] - ] + # Process sub-agents + if "sub_agents" in config and config["sub_agents"] is not None: + processed_config["sub_agents"] = [ + str(agent_id) for agent_id in config["sub_agents"] + ] - agent.config = config + # Process tools + if "tools" in config and config["tools"] is not None: + processed_tools = [] + for tool in config["tools"]: + # Convert tool id to string + tool_id = tool["id"] + + # Validar envs para garantir que não é None + envs = tool.get("envs", {}) + if envs is None: + envs = {} + + processed_tools.append({"id": str(tool_id), "envs": envs}) + processed_config["tools"] = processed_tools + + agent.config = processed_config # Ensure all config objects are serializable (convert UUIDs to strings) if agent.config is not None: @@ -398,82 +429,117 @@ async def update_agent( if "config" in agent_data: config = agent_data["config"] + # Preservar todos os campos originais + processed_config = {} + processed_config["api_key"] = config.get("api_key", "") + + # Copiar campos originais + if "tools" in config: + processed_config["tools"] = config["tools"] + + if "custom_tools" in config: + processed_config["custom_tools"] = config["custom_tools"] + + if "sub_agents" in config: + processed_config["sub_agents"] = config["sub_agents"] + + if "custom_mcp_servers" in config: + processed_config["custom_mcp_servers"] = config["custom_mcp_servers"] + + # Preservar outros campos não processados especificamente + for key, value in config.items(): + if key not in [ + "api_key", + "tools", + "custom_tools", + "sub_agents", + "custom_mcp_servers", + "mcp_servers", + ]: + processed_config[key] = value + + # Processar apenas campos que precisam de processamento # Process MCP servers - if "mcp_servers" in config: - if config["mcp_servers"] is not None: - processed_servers = [] - for server in config["mcp_servers"]: - # Convert server id to UUID if it's a string - server_id = server["id"] - if isinstance(server_id, str): - server_id = uuid.UUID(server_id) + if "mcp_servers" in config and config["mcp_servers"] is not None: + processed_servers = [] + for server in config["mcp_servers"]: + # Convert server id to UUID if it's a string + server_id = server["id"] + if isinstance(server_id, str): + server_id = uuid.UUID(server_id) - # Search for MCP server in the database - mcp_server = get_mcp_server(db, server_id) - if not mcp_server: - raise HTTPException( - status_code=400, - detail=f"MCP server not found: {server['id']}", - ) - - # Check if all required environment variables are provided - for env_key, env_value in mcp_server.environments.items(): - if env_key not in server.get("envs", {}): - raise HTTPException( - status_code=400, - detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}", - ) - - # Add the processed server - processed_servers.append( - { - "id": str(server["id"]), - "envs": server["envs"], - "tools": server["tools"], - } + # Search for MCP server in the database + mcp_server = get_mcp_server(db, server_id) + if not mcp_server: + raise HTTPException( + status_code=400, + detail=f"MCP server not found: {server['id']}", ) - config["mcp_servers"] = processed_servers - else: - config["mcp_servers"] = [] + # Check if all required environment variables are provided + for env_key, env_value in mcp_server.environments.items(): + if env_key not in server.get("envs", {}): + raise HTTPException( + status_code=400, + detail=f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}", + ) + + # Add the processed server + processed_servers.append( + { + "id": str(server["id"]), + "envs": server["envs"], + "tools": server["tools"], + } + ) + + processed_config["mcp_servers"] = processed_servers + elif "mcp_servers" in config: + processed_config["mcp_servers"] = config["mcp_servers"] # Process custom MCP servers - if "custom_mcp_servers" in config: - if config["custom_mcp_servers"] is not None: - processed_custom_servers = [] - for server in config["custom_mcp_servers"]: - # Validate URL format - if not server.get("url"): - raise HTTPException( - status_code=400, - detail="URL is required for custom MCP servers", - ) - - # Add the custom server - processed_custom_servers.append( - {"url": server["url"], "headers": server.get("headers", {})} + if ( + "custom_mcp_servers" in config + and config["custom_mcp_servers"] is not None + ): + processed_custom_servers = [] + for server in config["custom_mcp_servers"]: + # Validate URL format + if not server.get("url"): + raise HTTPException( + status_code=400, + detail="URL is required for custom MCP servers", ) - config["custom_mcp_servers"] = processed_custom_servers - else: - config["custom_mcp_servers"] = [] + # Add the custom server + processed_custom_servers.append( + {"url": server["url"], "headers": server.get("headers", {})} + ) + + processed_config["custom_mcp_servers"] = processed_custom_servers # Process sub-agents - if "sub_agents" in config: - if config["sub_agents"] is not None: - config["sub_agents"] = [ - str(agent_id) for agent_id in config["sub_agents"] - ] + if "sub_agents" in config and config["sub_agents"] is not None: + processed_config["sub_agents"] = [ + str(agent_id) for agent_id in config["sub_agents"] + ] # Process tools - if "tools" in config: - if config["tools"] is not None: - config["tools"] = [ - {"id": str(tool["id"]), "envs": tool["envs"]} - for tool in config["tools"] - ] + if "tools" in config and config["tools"] is not None: + processed_tools = [] + for tool in config["tools"]: + # Convert tool id to string + tool_id = tool["id"] - agent_data["config"] = config + # Validar envs para garantir que não é None + envs = tool.get("envs", {}) + if envs is None: + envs = {} + + processed_tools.append({"id": str(tool_id), "envs": envs}) + processed_config["tools"] = processed_tools + + agent_data["config"] = processed_config # Ensure all config objects are serializable (convert UUIDs to strings) if "config" in agent_data and agent_data["config"] is not None: