diff --git a/src/schemas/agent_config.py b/src/schemas/agent_config.py index 89710b02..9ae5b436 100644 --- a/src/schemas/agent_config.py +++ b/src/schemas/agent_config.py @@ -32,6 +32,18 @@ class MCPServerConfig(BaseModel): from_attributes = True +class CustomMCPServerConfig(BaseModel): + """Configuration of a custom MCP server""" + + url: str = Field(..., description="Server URL of the custom MCP server") + headers: Dict[str, str] = Field( + default_factory=dict, description="Headers for requests to the server" + ) + + class Config: + from_attributes = True + + class HTTPToolParameter(BaseModel): """Parameter of an HTTP tool""" @@ -115,6 +127,9 @@ class LLMConfig(BaseModel): mcp_servers: Optional[List[MCPServerConfig]] = Field( default=None, description="List of MCP servers" ) + custom_mcp_servers: Optional[List[CustomMCPServerConfig]] = Field( + default=None, description="List of custom MCP servers with URL and headers" + ) sub_agents: Optional[List[UUID]] = Field( default=None, description="List of IDs of sub-agents" ) diff --git a/src/services/agent_service.py b/src/services/agent_service.py index 07f4fd64..8a667c2c 100644 --- a/src/services/agent_service.py +++ b/src/services/agent_service.py @@ -175,6 +175,20 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent: # Process the configuration before creating the agent config = agent.config + if config is None: + config = {} + agent.config = config + + # Ensure config is a dictionary + if not isinstance(config, dict): + config = {} + agent.config = config + + # Generate automatic API key if not provided or empty + if not config.get("api_key") or config.get("api_key") == "": + logger.info("Generating automatic API key for new agent") + config["api_key"] = generate_api_key() + if isinstance(config, dict): # Process MCP servers if "mcp_servers" in config: @@ -212,6 +226,24 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent: config["mcp_servers"] = processed_servers + # Process custom MCP servers + if "custom_mcp_servers" in config: + processed_custom_servers = [] + for server in config["custom_mcp_servers"]: + # Validate URL format + if not server.get("url"): + raise HTTPException( + status_code=400, + detail="URL is required for custom MCP servers", + ) + + # Add the custom server + processed_custom_servers.append( + {"url": server["url"], "headers": server.get("headers", {})} + ) + + config["custom_mcp_servers"] = processed_custom_servers + # Process sub-agents if "sub_agents" in config: config["sub_agents"] = [ @@ -225,11 +257,6 @@ async def create_agent(db: Session, agent: AgentCreate) -> Agent: for tool in config["tools"] ] - # Generate automatic API key if not provided or empty - if not config.get("api_key") or config.get("api_key") == "": - logger.info("Generating automatic API key for new agent") - config["api_key"] = generate_api_key() - agent.config = config # Ensure all config objects are serializable (convert UUIDs to strings) @@ -399,6 +426,24 @@ async def update_agent( config["mcp_servers"] = processed_servers + # Process custom MCP servers + if "custom_mcp_servers" in config: + processed_custom_servers = [] + for server in config["custom_mcp_servers"]: + # Validate URL format + if not server.get("url"): + raise HTTPException( + status_code=400, + detail="URL is required for custom MCP servers", + ) + + # Add the custom server + processed_custom_servers.append( + {"url": server["url"], "headers": server.get("headers", {})} + ) + + config["custom_mcp_servers"] = processed_custom_servers + # Process sub-agents if "sub_agents" in config: config["sub_agents"] = [ diff --git a/src/services/mcp_service.py b/src/services/mcp_service.py index e53b20b0..c8146e85 100644 --- a/src/services/mcp_service.py +++ b/src/services/mcp_service.py @@ -135,7 +135,7 @@ class MCPService: # Registers the exit_stack with the AsyncExitStack await self.exit_stack.enter_async_context(exit_stack) logger.info( - f"Connected successfully. Added {len(filtered_tools)} tools." + f"MCP Server {mcp_server.name} connected successfully. Added {len(filtered_tools)} tools." ) else: logger.warning( @@ -146,6 +146,19 @@ class MCPService: logger.error(f"Error connecting to MCP server {server['id']}: {e}") continue + # Process custom MCP servers + for server in mcp_config.get("custom_mcp_servers", []): + try: + tools, exit_stack = await self._connect_to_mcp_server(server) + self.tools.extend(tools) + await self.exit_stack.enter_async_context(exit_stack) + logger.info( + f"Custom MCP server connected successfully. Added {len(tools)} tools." + ) + except Exception as e: + logger.error(f"Error connecting to MCP server {server['id']}: {e}") + continue + logger.info( f"MCP Toolset created successfully. Total of {len(self.tools)} tools." )