feat(api): integrate new AI engines and update chat routes for dynamic agent handling
This commit is contained in:
parent
9f176bf0e0
commit
cf24a7ce5d
@ -51,6 +51,9 @@ dependencies = [
|
||||
"langgraph==0.4.1",
|
||||
"opentelemetry-sdk==1.33.0",
|
||||
"opentelemetry-exporter-otlp==1.33.0",
|
||||
"mcp==1.9.0",
|
||||
"crewai==0.120.1",
|
||||
"crewai-tools==0.45.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
@ -39,6 +39,7 @@ from fastapi import (
|
||||
Header,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from src.config.settings import settings
|
||||
from src.config.database import get_db
|
||||
from src.core.jwt_middleware import (
|
||||
get_jwt_token,
|
||||
@ -49,7 +50,8 @@ from src.services import (
|
||||
agent_service,
|
||||
)
|
||||
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
|
||||
from src.services.agent_runner import run_agent, run_agent_stream
|
||||
from src.services.adk.agent_runner import run_agent as run_agent_adk, run_agent_stream
|
||||
from src.services.crewai.agent_runner import run_agent as run_agent_crewai
|
||||
from src.core.exceptions import AgentNotFoundError
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
@ -262,7 +264,7 @@ async def websocket_chat(
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
"/{agent_id}/{external_id}",
|
||||
response_model=ChatResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse},
|
||||
@ -272,20 +274,32 @@ async def websocket_chat(
|
||||
)
|
||||
async def chat(
|
||||
request: ChatRequest,
|
||||
agent_id: str,
|
||||
external_id: str,
|
||||
_=Depends(get_agent_by_api_key),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
final_response = await run_agent(
|
||||
request.agent_id,
|
||||
request.external_id,
|
||||
request.message,
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
db,
|
||||
files=request.files,
|
||||
)
|
||||
if settings.AI_ENGINE == "adk":
|
||||
final_response = await run_agent_adk(
|
||||
agent_id,
|
||||
external_id,
|
||||
request.message,
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
db,
|
||||
files=request.files,
|
||||
)
|
||||
elif settings.AI_ENGINE == "crewai":
|
||||
final_response = await run_agent_crewai(
|
||||
agent_id,
|
||||
external_id,
|
||||
request.message,
|
||||
session_service,
|
||||
db,
|
||||
files=request.files,
|
||||
)
|
||||
|
||||
return {
|
||||
"response": final_response["final_response"],
|
||||
|
@ -0,0 +1,3 @@
|
||||
from src.config.settings import settings
|
||||
|
||||
__all__ = ["settings"]
|
@ -57,6 +57,9 @@ class Settings(BaseSettings):
|
||||
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
|
||||
)
|
||||
|
||||
# AI engine settings
|
||||
AI_ENGINE: str = os.getenv("AI_ENGINE", "adk")
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_DIR: str = "logs"
|
||||
@ -83,11 +86,11 @@ class Settings(BaseSettings):
|
||||
|
||||
# Email provider settings
|
||||
EMAIL_PROVIDER: str = os.getenv("EMAIL_PROVIDER", "sendgrid")
|
||||
|
||||
|
||||
# SendGrid settings
|
||||
SENDGRID_API_KEY: str = os.getenv("SENDGRID_API_KEY", "")
|
||||
EMAIL_FROM: str = os.getenv("EMAIL_FROM", "noreply@yourdomain.com")
|
||||
|
||||
|
||||
# SMTP settings
|
||||
SMTP_HOST: str = os.getenv("SMTP_HOST", "")
|
||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", 587))
|
||||
@ -96,7 +99,7 @@ class Settings(BaseSettings):
|
||||
SMTP_USE_TLS: bool = os.getenv("SMTP_USE_TLS", "true").lower() == "true"
|
||||
SMTP_USE_SSL: bool = os.getenv("SMTP_USE_SSL", "false").lower() == "true"
|
||||
SMTP_FROM: str = os.getenv("SMTP_FROM", "")
|
||||
|
||||
|
||||
APP_URL: str = os.getenv("APP_URL", "http://localhost:8000")
|
||||
|
||||
# Server settings
|
||||
|
@ -43,9 +43,11 @@ class FileData(BaseModel):
|
||||
class ChatRequest(BaseModel):
|
||||
"""Model to represent a chat request."""
|
||||
|
||||
agent_id: str = Field(..., description="Agent ID to process the message")
|
||||
external_id: str = Field(..., description="External ID for user identification")
|
||||
message: str = Field(..., description="User message to the agent")
|
||||
agent_id: Optional[str] = Field(None, description="Agent ID to process the message")
|
||||
external_id: Optional[str] = Field(
|
||||
None, description="External ID for user identification"
|
||||
)
|
||||
files: Optional[List[FileData]] = Field(
|
||||
None, description="List of files attached to the message"
|
||||
)
|
||||
|
@ -1 +1 @@
|
||||
from .agent_runner import run_agent
|
||||
from .adk.agent_runner import run_agent
|
||||
|
@ -45,7 +45,7 @@ from src.services.agent_service import (
|
||||
)
|
||||
from src.services.mcp_server_service import get_mcp_server
|
||||
|
||||
from src.services.agent_runner import run_agent, run_agent_stream
|
||||
from src.services.adk.agent_runner import run_agent, run_agent_stream
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
artifacts_service,
|
||||
@ -388,7 +388,6 @@ class A2ATaskManager:
|
||||
self, request: SendTaskStreamingRequest, agent: Agent
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
"""Processes a task in streaming mode using the specified agent."""
|
||||
# Extrair e processar arquivos da mesma forma que no método _process_task
|
||||
query = self._extract_user_query(request.params)
|
||||
|
||||
try:
|
||||
@ -448,21 +447,19 @@ class A2ATaskManager:
|
||||
),
|
||||
)
|
||||
|
||||
# Use os arquivos processados do _extract_user_query
|
||||
files = getattr(self, "_last_processed_files", None)
|
||||
|
||||
# Log sobre os arquivos processados
|
||||
if files:
|
||||
logger.info(
|
||||
f"Streaming: Passando {len(files)} arquivos processados para run_agent_stream"
|
||||
f"Streaming: Uploading {len(files)} files to run_agent_stream"
|
||||
)
|
||||
for file_info in files:
|
||||
logger.info(
|
||||
f"Streaming: Arquivo sendo enviado: {file_info.filename} ({file_info.content_type})"
|
||||
f"Streaming: File being sent: {file_info.filename} ({file_info.content_type})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Streaming: Nenhum arquivo processado disponível para enviar ao agente"
|
||||
"Streaming: No processed files available to send to the agent"
|
||||
)
|
||||
|
||||
async for chunk in run_agent_stream(
|
||||
@ -473,7 +470,7 @@ class A2ATaskManager:
|
||||
artifacts_service=artifacts_service,
|
||||
memory_service=memory_service,
|
||||
db=self.db,
|
||||
files=files, # Passar os arquivos processados para o streaming
|
||||
files=files,
|
||||
):
|
||||
try:
|
||||
chunk_data = json.loads(chunk)
|
||||
|
@ -36,11 +36,11 @@ from src.schemas.schemas import Agent
|
||||
from src.utils.logger import setup_logger
|
||||
from src.core.exceptions import AgentNotFoundError
|
||||
from src.services.agent_service import get_agent
|
||||
from src.services.custom_tools import CustomToolBuilder
|
||||
from src.services.mcp_service import MCPService
|
||||
from src.services.custom_agents.a2a_agent import A2ACustomAgent
|
||||
from src.services.custom_agents.workflow_agent import WorkflowAgent
|
||||
from src.services.custom_agents.task_agent import TaskAgent
|
||||
from src.services.adk.custom_tools import CustomToolBuilder
|
||||
from src.services.adk.mcp_service import MCPService
|
||||
from src.services.adk.custom_agents.a2a_agent import A2ACustomAgent
|
||||
from src.services.adk.custom_agents.workflow_agent import WorkflowAgent
|
||||
from src.services.adk.custom_agents.task_agent import TaskAgent
|
||||
from src.services.apikey_service import get_decrypted_api_key
|
||||
from sqlalchemy.orm import Session
|
||||
from contextlib import AsyncExitStack
|
@ -35,7 +35,7 @@ from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactServ
|
||||
from src.utils.logger import setup_logger
|
||||
from src.core.exceptions import AgentNotFoundError, InternalServerError
|
||||
from src.services.agent_service import get_agent
|
||||
from src.services.agent_builder import AgentBuilder
|
||||
from src.services.adk.agent_builder import AgentBuilder
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, AsyncGenerator
|
||||
import asyncio
|
0
src/services/adk/custom_agents/__init__.py
Normal file
0
src/services/adk/custom_agents/__init__.py
Normal file
@ -162,7 +162,7 @@ class TaskAgent(BaseAgent):
|
||||
),
|
||||
)
|
||||
|
||||
from src.services.agent_builder import AgentBuilder
|
||||
from src.services.adk.agent_builder import AgentBuilder
|
||||
|
||||
print(f"Building agent in Task agent: {agent.name}")
|
||||
agent_builder = AgentBuilder(self.db)
|
@ -181,7 +181,7 @@ class WorkflowAgent(BaseAgent):
|
||||
return
|
||||
|
||||
# Import moved to inside the function to avoid circular import
|
||||
from src.services.agent_builder import AgentBuilder
|
||||
from src.services.adk.agent_builder import AgentBuilder
|
||||
|
||||
agent_builder = AgentBuilder(self.db)
|
||||
root_agent, exit_stack = await agent_builder.build_agent(agent)
|
219
src/services/crewai/agent_builder.py
Normal file
219
src/services/crewai/agent_builder.py
Normal file
@ -0,0 +1,219 @@
|
||||
"""
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ @author: Davidson Gomes │
|
||||
│ @file: agent_builder.py │
|
||||
│ Developed by: Davidson Gomes │
|
||||
│ Creation date: May 13, 2025 │
|
||||
│ Contact: contato@evolution-api.com │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @copyright © Evolution API 2025. All rights reserved. │
|
||||
│ Licensed under the Apache License, Version 2.0 │
|
||||
│ │
|
||||
│ You may not use this file except in compliance with the License. │
|
||||
│ You may obtain a copy of the License at │
|
||||
│ │
|
||||
│ http://www.apache.org/licenses/LICENSE-2.0 │
|
||||
│ │
|
||||
│ Unless required by applicable law or agreed to in writing, software │
|
||||
│ distributed under the License is distributed on an "AS IS" BASIS, │
|
||||
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
|
||||
│ See the License for the specific language governing permissions and │
|
||||
│ limitations under the License. │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @important │
|
||||
│ For any future changes to the code in this file, it is recommended to │
|
||||
│ include, together with the modification, the information of the developer │
|
||||
│ who changed it and the date of modification. │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from src.schemas.schemas import Agent
|
||||
from src.schemas.agent_config import AgentTask
|
||||
from src.services.crewai.custom_tool import CustomToolBuilder
|
||||
from src.services.crewai.mcp_service import MCPService
|
||||
from src.utils.logger import setup_logger
|
||||
from src.services.apikey_service import get_decrypted_api_key
|
||||
from sqlalchemy.orm import Session
|
||||
from contextlib import AsyncExitStack
|
||||
from crewai import LLM, Agent as LlmAgent, Crew, Task, Process
|
||||
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class AgentBuilder:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.custom_tool_builder = CustomToolBuilder()
|
||||
self.mcp_service = MCPService()
|
||||
|
||||
async def _get_api_key(self, agent: Agent) -> str:
|
||||
"""Get the API key for the agent."""
|
||||
api_key = None
|
||||
|
||||
# Get API key from api_key_id
|
||||
if hasattr(agent, "api_key_id") and agent.api_key_id:
|
||||
if decrypted_key := get_decrypted_api_key(self.db, agent.api_key_id):
|
||||
logger.info(f"Using stored API key for agent {agent.name}")
|
||||
api_key = decrypted_key
|
||||
else:
|
||||
logger.error(f"Stored API key not found for agent {agent.name}")
|
||||
raise ValueError(
|
||||
f"API key with ID {agent.api_key_id} not found or inactive"
|
||||
)
|
||||
else:
|
||||
# Check if there is an API key in the config (temporary field)
|
||||
config_api_key = agent.config.get("api_key") if agent.config else None
|
||||
if config_api_key:
|
||||
logger.info(f"Using config API key for agent {agent.name}")
|
||||
# Check if it is a UUID of a stored key
|
||||
try:
|
||||
key_id = uuid.UUID(config_api_key)
|
||||
if decrypted_key := get_decrypted_api_key(self.db, key_id):
|
||||
logger.info("Config API key is a valid reference")
|
||||
api_key = decrypted_key
|
||||
else:
|
||||
# Use the key directly
|
||||
api_key = config_api_key
|
||||
except (ValueError, TypeError):
|
||||
# It is not a UUID, use directly
|
||||
api_key = config_api_key
|
||||
else:
|
||||
logger.error(f"No API key configured for agent {agent.name}")
|
||||
raise ValueError(
|
||||
f"Agent {agent.name} does not have a configured API key"
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
async def _create_llm(self, agent: Agent) -> LLM:
|
||||
"""Create an LLM from the agent data."""
|
||||
api_key = await self._get_api_key(agent)
|
||||
|
||||
return LLM(model=agent.model, api_key=api_key)
|
||||
|
||||
async def _create_llm_agent(
|
||||
self, agent: Agent, enabled_tools: List[str] = []
|
||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||
"""Create an LLM agent from the agent data."""
|
||||
# Get custom tools from the configuration
|
||||
custom_tools = []
|
||||
custom_tools = self.custom_tool_builder.build_tools(agent.config)
|
||||
|
||||
# # Get MCP tools from the configuration
|
||||
mcp_tools = []
|
||||
mcp_exit_stack = None
|
||||
if agent.config.get("mcp_servers") or agent.config.get("custom_mcp_servers"):
|
||||
try:
|
||||
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
|
||||
agent.config, self.db
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error building MCP tools: {e}")
|
||||
# Continue without MCP tools
|
||||
mcp_tools = []
|
||||
mcp_exit_stack = None
|
||||
|
||||
# # Get agent tools
|
||||
# agent_tools = await self._agent_tools_builder(agent)
|
||||
|
||||
# Combine all tools
|
||||
all_tools = custom_tools + mcp_tools
|
||||
|
||||
if enabled_tools:
|
||||
all_tools = [tool for tool in all_tools if tool.name in enabled_tools]
|
||||
logger.info(f"Enabled tools enabled. Total tools: {len(all_tools)}")
|
||||
|
||||
now = datetime.now()
|
||||
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
||||
current_day_of_week = now.strftime("%A")
|
||||
current_date_iso = now.strftime("%Y-%m-%d")
|
||||
current_time = now.strftime("%H:%M")
|
||||
|
||||
# Substitute variables in the prompt
|
||||
formatted_prompt = agent.instruction.format(
|
||||
current_datetime=current_datetime,
|
||||
current_day_of_week=current_day_of_week,
|
||||
current_date_iso=current_date_iso,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
llm_agent = LlmAgent(
|
||||
role=agent.role,
|
||||
goal=agent.goal,
|
||||
backstory=formatted_prompt,
|
||||
llm=await self._create_llm(agent),
|
||||
tools=all_tools,
|
||||
verbose=True,
|
||||
cache=True,
|
||||
# memory=True,
|
||||
)
|
||||
|
||||
return llm_agent, mcp_exit_stack
|
||||
|
||||
async def _create_tasks(
|
||||
self, agent: LlmAgent, tasks: List[AgentTask] = []
|
||||
) -> List[Task]:
|
||||
"""Create tasks from the agent data."""
|
||||
tasks_list = []
|
||||
if tasks:
|
||||
tasks_list.extend(
|
||||
Task(
|
||||
name=task.name,
|
||||
description=task.description,
|
||||
expected_output=task.expected_output,
|
||||
agent=agent,
|
||||
verbose=True,
|
||||
)
|
||||
for task in tasks
|
||||
)
|
||||
return tasks_list
|
||||
|
||||
async def build_crew(self, agents: List[LlmAgent], tasks: List[Task] = []) -> Crew:
|
||||
"""Create a crew from the agent data."""
|
||||
return Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
async def build_llm_agent(
|
||||
self, root_agent, enabled_tools: List[str] = []
|
||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||
"""Build an LLM agent with its sub-agents."""
|
||||
logger.info("Creating LLM agent")
|
||||
|
||||
try:
|
||||
result = await self._create_llm_agent(root_agent, enabled_tools)
|
||||
|
||||
if isinstance(result, tuple) and len(result) == 2:
|
||||
return result
|
||||
else:
|
||||
return result, None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in build_llm_agent: {e}")
|
||||
raise
|
||||
|
||||
async def build_agent(
|
||||
self, root_agent, enabled_tools: List[str] = []
|
||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||
"""Build the appropriate agent based on the type of the root agent."""
|
||||
if root_agent.type == "llm":
|
||||
agent, exit_stack = await self.build_llm_agent(root_agent, enabled_tools)
|
||||
return agent, exit_stack
|
||||
elif root_agent.type == "a2a":
|
||||
raise ValueError("A2A agents are not supported yet")
|
||||
# return await self.build_a2a_agent(root_agent)
|
||||
elif root_agent.type == "workflow":
|
||||
raise ValueError("Workflow agents are not supported yet")
|
||||
# return await self.build_workflow_agent(root_agent)
|
||||
elif root_agent.type == "task":
|
||||
raise ValueError("Task agents are not supported yet")
|
||||
# return await self.build_task_agent(root_agent)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {root_agent.type}")
|
||||
# return await self.build_composite_agent(root_agent)
|
595
src/services/crewai/agent_runner.py
Normal file
595
src/services/crewai/agent_runner.py
Normal file
@ -0,0 +1,595 @@
|
||||
"""
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ @author: Davidson Gomes │
|
||||
│ @file: agent_runner.py │
|
||||
│ Developed by: Davidson Gomes │
|
||||
│ Creation date: May 13, 2025 │
|
||||
│ Contact: contato@evolution-api.com │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @copyright © Evolution API 2025. All rights reserved. │
|
||||
│ Licensed under the Apache License, Version 2.0 │
|
||||
│ │
|
||||
│ You may not use this file except in compliance with the License. │
|
||||
│ You may obtain a copy of the License at │
|
||||
│ │
|
||||
│ http://www.apache.org/licenses/LICENSE-2.0 │
|
||||
│ │
|
||||
│ Unless required by applicable law or agreed to in writing, software │
|
||||
│ distributed under the License is distributed on an "AS IS" BASIS, │
|
||||
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
|
||||
│ See the License for the specific language governing permissions and │
|
||||
│ limitations under the License. │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @important │
|
||||
│ For any future changes to the code in this file, it is recommended to │
|
||||
│ include, together with the modification, the information of the developer │
|
||||
│ who changed it and the date of modification. │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from crewai import Crew, Task, Agent as LlmAgent
|
||||
from src.services.crewai.session_service import (
|
||||
CrewSessionService,
|
||||
Event,
|
||||
Content,
|
||||
Part,
|
||||
Session,
|
||||
)
|
||||
from src.services.crewai.agent_builder import AgentBuilder
|
||||
from src.utils.logger import setup_logger
|
||||
from src.core.exceptions import AgentNotFoundError, InternalServerError
|
||||
from src.services.agent_service import get_agent
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, AsyncGenerator
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from src.utils.otel import get_tracer
|
||||
from opentelemetry import trace
|
||||
import base64
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def extract_text_from_output(crew_output):
|
||||
"""Extract text from CrewOutput object."""
|
||||
if hasattr(crew_output, "raw") and crew_output.raw:
|
||||
return crew_output.raw
|
||||
elif hasattr(crew_output, "__str__"):
|
||||
return str(crew_output)
|
||||
|
||||
# Fallback if no text found
|
||||
return "Unable to extract a valid response."
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_id: str,
|
||||
external_id: str,
|
||||
message: str,
|
||||
session_service: CrewSessionService,
|
||||
db: Session,
|
||||
session_id: Optional[str] = None,
|
||||
timeout: float = 60.0,
|
||||
files: Optional[list] = None,
|
||||
):
|
||||
tracer = get_tracer()
|
||||
with tracer.start_as_current_span(
|
||||
"run_agent",
|
||||
attributes={
|
||||
"agent_id": agent_id,
|
||||
"external_id": external_id,
|
||||
"session_id": session_id or f"{external_id}_{agent_id}",
|
||||
"message": message,
|
||||
"has_files": files is not None and len(files) > 0,
|
||||
},
|
||||
):
|
||||
exit_stack = None
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting execution of agent {agent_id} for external_id {external_id}"
|
||||
)
|
||||
logger.info(f"Received message: {message}")
|
||||
|
||||
if files and len(files) > 0:
|
||||
logger.info(f"Received {len(files)} files with message")
|
||||
|
||||
get_root_agent = get_agent(db, agent_id)
|
||||
logger.info(
|
||||
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
|
||||
)
|
||||
|
||||
if get_root_agent is None:
|
||||
raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
|
||||
|
||||
# Using the AgentBuilder to create the agent
|
||||
agent_builder = AgentBuilder(db)
|
||||
result = await agent_builder.build_agent(get_root_agent)
|
||||
|
||||
# Check how the result is structured
|
||||
if isinstance(result, tuple) and len(result) == 2:
|
||||
root_agent, exit_stack = result
|
||||
else:
|
||||
# If the result is not a tuple of 2 elements
|
||||
root_agent = result
|
||||
exit_stack = None
|
||||
logger.warning("build_agent did not return an exit_stack")
|
||||
|
||||
# TODO: files should be processed here
|
||||
|
||||
# Fetch session information
|
||||
crew_session_id = f"{external_id}_{agent_id}"
|
||||
if session_id is None:
|
||||
session_id = crew_session_id
|
||||
|
||||
logger.info(f"Searching session for external_id {external_id}")
|
||||
try:
|
||||
session = session_service.get_session(
|
||||
agent_id=agent_id,
|
||||
external_id=external_id,
|
||||
session_id=crew_session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting session: {str(e)}")
|
||||
session = None
|
||||
|
||||
if session is None:
|
||||
logger.info(f"Creating new session for external_id {external_id}")
|
||||
session = session_service.create_session(
|
||||
agent_id=agent_id,
|
||||
external_id=external_id,
|
||||
session_id=crew_session_id,
|
||||
)
|
||||
|
||||
# Add user message to session
|
||||
session.events.append(
|
||||
Event(
|
||||
author="user",
|
||||
content=Content(parts=[{"text": message}]),
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
|
||||
# Save session to database
|
||||
session_service.save_session(session)
|
||||
|
||||
# Build message history for context
|
||||
conversation_history = []
|
||||
if session and session.events:
|
||||
for event in session.events:
|
||||
if event.author and event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if isinstance(part, dict) and "text" in part:
|
||||
role = "User" if event.author == "user" else "Assistant"
|
||||
conversation_history.append(f"{role}: {part['text']}")
|
||||
|
||||
# Build description with history as context
|
||||
task_description = (
|
||||
f"Conversation history:\n" + "\n".join(conversation_history)
|
||||
if conversation_history
|
||||
else ""
|
||||
)
|
||||
task_description += f"\n\nCurrent user message: {message}"
|
||||
|
||||
task = Task(
|
||||
name="resolve_user_request",
|
||||
description=task_description,
|
||||
expected_output="Response to the user request",
|
||||
agent=root_agent,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
crew = await agent_builder.build_crew([root_agent], [task])
|
||||
|
||||
# Use normal kickoff or kickoff_async instead of kickoff_for_each
|
||||
if hasattr(crew, "kickoff_async"):
|
||||
crew_output = await crew.kickoff_async(inputs={"message": message})
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
crew_output = await loop.run_in_executor(
|
||||
None, lambda: crew.kickoff(inputs={"message": message})
|
||||
)
|
||||
|
||||
# Extract response and add to session
|
||||
final_text = extract_text_from_output(crew_output)
|
||||
|
||||
# Add agent response as event in session
|
||||
session.events.append(
|
||||
Event(
|
||||
author=get_root_agent.name,
|
||||
content=Content(parts=[{"text": final_text}]),
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
|
||||
# Save session with new event
|
||||
session_service.save_session(session)
|
||||
|
||||
logger.info("Starting agent execution")
|
||||
|
||||
final_response_text = "No final response captured."
|
||||
message_history = []
|
||||
|
||||
try:
|
||||
response_queue = asyncio.Queue()
|
||||
execution_completed = asyncio.Event()
|
||||
|
||||
async def process_events():
|
||||
try:
|
||||
# Log the result
|
||||
logger.info(f"Crew output: {crew_output}")
|
||||
|
||||
# Signal that execution is complete
|
||||
execution_completed.set()
|
||||
|
||||
# Extract text from CrewOutput object
|
||||
final_text = "Unable to extract a valid response."
|
||||
|
||||
if hasattr(crew_output, "raw") and crew_output.raw:
|
||||
final_text = crew_output.raw
|
||||
elif hasattr(crew_output, "__str__"):
|
||||
final_text = str(crew_output)
|
||||
|
||||
# If still empty or None, check crew artifacts
|
||||
if not final_text or final_text.strip() == "":
|
||||
# Try to get from agent messages
|
||||
if hasattr(root_agent, "messages") and root_agent.messages:
|
||||
# Get the last message from the agent
|
||||
for msg in reversed(root_agent.messages):
|
||||
if hasattr(msg, "content") and msg.content:
|
||||
final_text = msg.content
|
||||
break
|
||||
|
||||
# If still empty, use a fallback
|
||||
if not final_text or final_text.strip() == "":
|
||||
final_text = "The agent could not produce a valid response. Please try again with a different question."
|
||||
|
||||
# Put the extracted text in the queue
|
||||
await response_queue.put(final_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_events: {str(e)}")
|
||||
# Provide a more helpful error response
|
||||
error_response = f"An error occurred during processing: {str(e)}\n\nIf you are trying to use external tools such as Brave Search, please make sure the connection is working properly."
|
||||
await response_queue.put(error_response)
|
||||
execution_completed.set()
|
||||
|
||||
task = asyncio.create_task(process_events())
|
||||
|
||||
try:
|
||||
wait_task = asyncio.create_task(execution_completed.wait())
|
||||
done, pending = await asyncio.wait({wait_task}, timeout=timeout)
|
||||
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
|
||||
if not execution_completed.is_set():
|
||||
logger.warning(
|
||||
f"Agent execution timed out after {timeout} seconds"
|
||||
)
|
||||
await response_queue.put(
|
||||
"The response took too long and was interrupted."
|
||||
)
|
||||
|
||||
final_response_text = await response_queue.get()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for response: {str(e)}")
|
||||
final_response_text = f"Error processing response: {str(e)}"
|
||||
|
||||
# Add the session to memory after completion
|
||||
# completed_session = session_service.get_session(
|
||||
# app_name=agent_id,
|
||||
# user_id=external_id,
|
||||
# session_id=crew_session_id,
|
||||
# )
|
||||
|
||||
# memory_service.add_session_to_memory(completed_session)
|
||||
|
||||
# Cancel the processing task if it is still running
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling task: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise InternalServerError(str(e)) from e
|
||||
|
||||
logger.info("Agent execution completed successfully")
|
||||
return {
|
||||
"final_response": final_response_text,
|
||||
"message_history": message_history,
|
||||
}
|
||||
except AgentNotFoundError as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
|
||||
raise InternalServerError(str(e))
|
||||
finally:
|
||||
# Clean up MCP connection - MUST be executed in the same task
|
||||
if exit_stack:
|
||||
logger.info("Closing MCP server connection...")
|
||||
try:
|
||||
if hasattr(exit_stack, "aclose"):
|
||||
# If it's an AsyncExitStack
|
||||
await exit_stack.aclose()
|
||||
elif isinstance(exit_stack, list):
|
||||
# If it's a list of adapters
|
||||
for adapter in exit_stack:
|
||||
if hasattr(adapter, "close"):
|
||||
adapter.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing MCP connection: {e}")
|
||||
# Do not raise the exception to not obscure the original error
|
||||
|
||||
|
||||
def convert_sets(obj):
|
||||
if isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: convert_sets(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_sets(i) for i in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
async def run_agent_stream(
|
||||
agent_id: str,
|
||||
external_id: str,
|
||||
message: str,
|
||||
db: Session,
|
||||
session_id: Optional[str] = None,
|
||||
files: Optional[list] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
tracer = get_tracer()
|
||||
span = tracer.start_span(
|
||||
"run_agent_stream",
|
||||
attributes={
|
||||
"agent_id": agent_id,
|
||||
"external_id": external_id,
|
||||
"session_id": session_id or f"{external_id}_{agent_id}",
|
||||
"message": message,
|
||||
"has_files": files is not None and len(files) > 0,
|
||||
},
|
||||
)
|
||||
exit_stack = None
|
||||
try:
|
||||
with trace.use_span(span, end_on_exit=True):
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting streaming execution of agent {agent_id} for external_id {external_id}"
|
||||
)
|
||||
logger.info(f"Received message: {message}")
|
||||
|
||||
if files and len(files) > 0:
|
||||
logger.info(f"Received {len(files)} files with message")
|
||||
|
||||
get_root_agent = get_agent(db, agent_id)
|
||||
logger.info(
|
||||
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
|
||||
)
|
||||
|
||||
if get_root_agent is None:
|
||||
raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
|
||||
|
||||
# Using the AgentBuilder to create the agent
|
||||
agent_builder = AgentBuilder(db)
|
||||
result = await agent_builder.build_agent(get_root_agent)
|
||||
|
||||
# Check how the result is structured
|
||||
if isinstance(result, tuple) and len(result) == 2:
|
||||
root_agent, exit_stack = result
|
||||
else:
|
||||
# If the result is not a tuple of 2 elements
|
||||
root_agent = result
|
||||
exit_stack = None
|
||||
logger.warning("build_agent did not return an exit_stack")
|
||||
|
||||
# TODO: files should be processed here
|
||||
|
||||
# Fetch session history if available
|
||||
session_id = f"{external_id}_{agent_id}"
|
||||
|
||||
# Create an instance of the session service
|
||||
try:
|
||||
from src.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
db_url = settings.DATABASE_URL
|
||||
except ImportError:
|
||||
# Fallback to local SQLite if cannot import settings
|
||||
db_url = "sqlite:///data/crew_sessions.db"
|
||||
|
||||
session_service = CrewSessionService(db_url)
|
||||
|
||||
try:
|
||||
# Try to get existing session
|
||||
session = session_service.get_session(
|
||||
agent_id=agent_id,
|
||||
external_id=external_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load session: {e}")
|
||||
session = None
|
||||
|
||||
# Build message history for context
|
||||
conversation_history = []
|
||||
|
||||
if session and session.events:
|
||||
for event in session.events:
|
||||
if event.author and event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if isinstance(part, dict) and "text" in part:
|
||||
role = (
|
||||
"User"
|
||||
if event.author == "user"
|
||||
else "Assistant"
|
||||
)
|
||||
conversation_history.append(
|
||||
f"{role}: {part['text']}"
|
||||
)
|
||||
|
||||
# Build description with history
|
||||
task_description = (
|
||||
f"Conversation history:\n" + "\n".join(conversation_history)
|
||||
if conversation_history
|
||||
else ""
|
||||
)
|
||||
task_description += f"\n\nCurrent user message: {message}"
|
||||
|
||||
task = Task(
|
||||
name="resolve_user_request",
|
||||
description=task_description,
|
||||
expected_output="Response to the user request",
|
||||
agent=root_agent,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
crew = await agent_builder.build_crew([root_agent], [task])
|
||||
|
||||
logger.info("Starting agent streaming execution")
|
||||
|
||||
try:
|
||||
# Check if we can process messages with kickoff_for_each
|
||||
if hasattr(crew, "kickoff_for_each"):
|
||||
# Create input with current message
|
||||
inputs = [{"message": message}]
|
||||
logger.info(
|
||||
f"Using kickoff_for_each for streaming with {len(inputs)} input(s)"
|
||||
)
|
||||
|
||||
# Execute kickoff_for_each
|
||||
results = crew.kickoff_for_each(inputs=inputs)
|
||||
|
||||
# Print results and save to session
|
||||
for i, result in enumerate(results):
|
||||
logger.info(f"Result of event {i+1}: {result}")
|
||||
|
||||
# If we have a session, save the response to it
|
||||
if session:
|
||||
# Add agent response as event
|
||||
session.events.append(
|
||||
Event(
|
||||
author="agent",
|
||||
content=Content(parts=[{"text": result}]),
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
|
||||
# Save current session with new message
|
||||
if session:
|
||||
# Also add user message if it doesn't exist yet
|
||||
if not any(
|
||||
e.author == "user"
|
||||
and any(
|
||||
p.get("text") == message for p in e.content.parts
|
||||
)
|
||||
for e in session.events
|
||||
if e.content and e.content.parts
|
||||
):
|
||||
session.events.append(
|
||||
Event(
|
||||
author="user",
|
||||
content=Content(parts=[{"text": message}]),
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
# Save session
|
||||
try:
|
||||
session_service.save_session(session)
|
||||
logger.info(f"Session saved successfully: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session: {e}")
|
||||
|
||||
# Use last result as final output
|
||||
crew_output = results[-1] if results else None
|
||||
else:
|
||||
# CrewAI kickoff method is synchronous, fallback if kickoff_for_each not available
|
||||
logger.info(
|
||||
"kickoff_for_each not available, using standard kickoff for streaming"
|
||||
)
|
||||
crew_output = crew.kickoff()
|
||||
|
||||
logger.info(f"Crew output: {crew_output}")
|
||||
|
||||
# Extract the actual text content
|
||||
if hasattr(crew_output, "raw") and crew_output.raw:
|
||||
final_output = crew_output.raw
|
||||
elif hasattr(crew_output, "__str__"):
|
||||
final_output = str(crew_output)
|
||||
else:
|
||||
final_output = "Could not extract text from response"
|
||||
|
||||
# Save response to session (for fallback case of normal kickoff)
|
||||
if session and not hasattr(crew, "kickoff_for_each"):
|
||||
# Add agent response
|
||||
session.events.append(
|
||||
Event(
|
||||
author="agent",
|
||||
content=Content(parts=[{"text": final_output}]),
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
|
||||
# Add user message if it doesn't exist yet
|
||||
if not any(
|
||||
e.author == "user"
|
||||
and any(p.get("text") == message for p in e.content.parts)
|
||||
for e in session.events
|
||||
if e.content and e.content.parts
|
||||
):
|
||||
session.events.append(
|
||||
Event(
|
||||
author="user",
|
||||
content=Content(parts=[{"text": message}]),
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
|
||||
# Save session
|
||||
try:
|
||||
session_service.save_session(session)
|
||||
logger.info(
|
||||
f"Session saved successfully (method: kickoff): {session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session: {e}")
|
||||
|
||||
yield json.dumps({"text": final_output})
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise InternalServerError(str(e)) from e
|
||||
finally:
|
||||
# Clean up MCP connection
|
||||
if exit_stack:
|
||||
logger.info("Closing MCP server connection...")
|
||||
try:
|
||||
if hasattr(exit_stack, "aclose"):
|
||||
# If it's an AsyncExitStack
|
||||
await exit_stack.aclose()
|
||||
elif isinstance(exit_stack, list):
|
||||
# If it's a list of adapters
|
||||
for adapter in exit_stack:
|
||||
if hasattr(adapter, "close"):
|
||||
adapter.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing MCP connection: {e}")
|
||||
# Do not raise the exception to not obscure the original error
|
||||
|
||||
logger.info("Agent streaming execution completed successfully")
|
||||
except AgentNotFoundError as e:
|
||||
logger.error(f"Error processing request: {str(e)}")
|
||||
raise InternalServerError(str(e)) from e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Internal error processing request: {str(e)}", exc_info=True
|
||||
)
|
||||
raise InternalServerError(str(e))
|
||||
finally:
|
||||
span.end()
|
369
src/services/crewai/custom_tool.py
Normal file
369
src/services/crewai/custom_tool.py
Normal file
@ -0,0 +1,369 @@
|
||||
"""
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ @author: Davidson Gomes │
|
||||
│ @file: custom_tool.py │
|
||||
│ Developed by: Davidson Gomes │
|
||||
│ Creation date: May 13, 2025 │
|
||||
│ Contact: contato@evolution-api.com │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @copyright © Evolution API 2025. All rights reserved. │
|
||||
│ Licensed under the Apache License, Version 2.0 │
|
||||
│ │
|
||||
│ You may not use this file except in compliance with the License. │
|
||||
│ You may obtain a copy of the License at │
|
||||
│ │
|
||||
│ http://www.apache.org/licenses/LICENSE-2.0 │
|
||||
│ │
|
||||
│ Unless required by applicable law or agreed to in writing, software │
|
||||
│ distributed under the License is distributed on an "AS IS" BASIS, │
|
||||
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
|
||||
│ See the License for the specific language governing permissions and │
|
||||
│ limitations under the License. │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @important │
|
||||
│ For any future changes to the code in this file, it is recommended to │
|
||||
│ include, together with the modification, the information of the developer │
|
||||
│ who changed it and the date of modification. │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
from crewai.tools import BaseTool, tool
|
||||
import requests
|
||||
import json
|
||||
from src.utils.logger import setup_logger
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class CustomToolBuilder:
|
||||
def __init__(self):
|
||||
self.tools = []
|
||||
|
||||
def _create_http_tool(self, tool_config: Dict[str, Any]) -> BaseTool:
|
||||
"""Create an HTTP tool based on the provided configuration."""
|
||||
# Extract configuration parameters
|
||||
name = tool_config["name"]
|
||||
description = tool_config["description"]
|
||||
endpoint = tool_config["endpoint"]
|
||||
method = tool_config["method"]
|
||||
headers = tool_config.get("headers", {})
|
||||
parameters = tool_config.get("parameters", {}) or {}
|
||||
values = tool_config.get("values", {})
|
||||
error_handling = tool_config.get("error_handling", {})
|
||||
|
||||
path_params = parameters.get("path_params") or {}
|
||||
query_params = parameters.get("query_params") or {}
|
||||
body_params = parameters.get("body_params") or {}
|
||||
|
||||
# Dynamic creation of the input schema for the tool
|
||||
field_definitions = {}
|
||||
|
||||
# Add all parameters as fields
|
||||
for param in (
|
||||
list(path_params.keys())
|
||||
+ list(query_params.keys())
|
||||
+ list(body_params.keys())
|
||||
):
|
||||
# Default to string type for all parameters
|
||||
field_definitions[param] = (
|
||||
str,
|
||||
Field(..., description=f"Parameter {param}"),
|
||||
)
|
||||
|
||||
# If there are no parameters but default values, use those as optional fields
|
||||
if not field_definitions and values:
|
||||
for param, value in values.items():
|
||||
param_type = type(value)
|
||||
field_definitions[param] = (
|
||||
param_type,
|
||||
Field(default=value, description=f"Parameter {param}"),
|
||||
)
|
||||
|
||||
# Create dynamic input schema model in line with the documentation
|
||||
tool_input_model = create_model(
|
||||
f"{name.replace(' ', '')}Input", **field_definitions
|
||||
)
|
||||
|
||||
# Create the HTTP tool using crewai's BaseTool class
|
||||
# Following the pattern in the documentation
|
||||
def create_http_tool_class():
|
||||
# Capture variables from outer scope
|
||||
_name = name
|
||||
_description = description
|
||||
_tool_input_model = tool_input_model
|
||||
|
||||
class HttpTool(BaseTool):
|
||||
name: str = _name
|
||||
description: str = _description
|
||||
args_schema: Type[BaseModel] = _tool_input_model
|
||||
|
||||
def _run(self, **kwargs):
|
||||
"""Execute the HTTP request and return the result."""
|
||||
try:
|
||||
# Combines default values with provided values
|
||||
all_values = {**values, **kwargs}
|
||||
|
||||
# Substitutes placeholders in headers
|
||||
processed_headers = {
|
||||
k: v.format(**all_values) if isinstance(v, str) else v
|
||||
for k, v in headers.items()
|
||||
}
|
||||
|
||||
# Processes path parameters
|
||||
url = endpoint
|
||||
for param, value in path_params.items():
|
||||
if param in all_values:
|
||||
url = url.replace(
|
||||
f"{{{param}}}", str(all_values[param])
|
||||
)
|
||||
|
||||
# Process query parameters
|
||||
query_params_dict = {}
|
||||
for param, value in query_params.items():
|
||||
if isinstance(value, list):
|
||||
# If the value is a list, join with comma
|
||||
query_params_dict[param] = ",".join(value)
|
||||
elif param in all_values:
|
||||
# If the parameter is in the values, use the provided value
|
||||
query_params_dict[param] = all_values[param]
|
||||
else:
|
||||
# Otherwise, use the default value from the configuration
|
||||
query_params_dict[param] = value
|
||||
|
||||
# Adds default values to query params if they are not present
|
||||
for param, value in values.items():
|
||||
if (
|
||||
param not in query_params_dict
|
||||
and param not in path_params
|
||||
):
|
||||
query_params_dict[param] = value
|
||||
|
||||
body_data = {}
|
||||
for param, param_config in body_params.items():
|
||||
if param in all_values:
|
||||
body_data[param] = all_values[param]
|
||||
|
||||
# Adds default values to body if they are not present
|
||||
for param, value in values.items():
|
||||
if (
|
||||
param not in body_data
|
||||
and param not in query_params_dict
|
||||
and param not in path_params
|
||||
):
|
||||
body_data[param] = value
|
||||
|
||||
# Makes the HTTP request
|
||||
response = requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=processed_headers,
|
||||
params=query_params_dict,
|
||||
json=body_data or None,
|
||||
timeout=error_handling.get("timeout", 30),
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise requests.exceptions.HTTPError(
|
||||
f"Error in the request: {response.status_code} - {response.text}"
|
||||
)
|
||||
|
||||
# Always returns the response as a string
|
||||
return json.dumps(response.json())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {name}: {str(e)}")
|
||||
return json.dumps(
|
||||
error_handling.get(
|
||||
"fallback_response",
|
||||
{"error": "tool_execution_error", "message": str(e)},
|
||||
)
|
||||
)
|
||||
|
||||
return HttpTool
|
||||
|
||||
# Create the tool instance
|
||||
HttpToolClass = create_http_tool_class()
|
||||
http_tool = HttpToolClass()
|
||||
|
||||
# Add cache function following the documentation
|
||||
def http_cache_function(arguments: dict, result: str) -> bool:
|
||||
"""Determines whether to cache the result based on arguments and result."""
|
||||
# Default implementation: cache all successful results
|
||||
try:
|
||||
# If the result is parseable JSON and not an error, cache it
|
||||
result_obj = json.loads(result)
|
||||
return not (isinstance(result_obj, dict) and "error" in result_obj)
|
||||
except Exception:
|
||||
# If result is not valid JSON, don't cache
|
||||
return False
|
||||
|
||||
# Assign the cache function to the tool
|
||||
http_tool.cache_function = http_cache_function
|
||||
|
||||
return http_tool
|
||||
|
||||
def _create_http_tool_with_decorator(self, tool_config: Dict[str, Any]) -> Any:
|
||||
"""Create an HTTP tool using the tool decorator."""
|
||||
# Extract configuration parameters
|
||||
name = tool_config["name"]
|
||||
description = tool_config["description"]
|
||||
endpoint = tool_config["endpoint"]
|
||||
method = tool_config["method"]
|
||||
headers = tool_config.get("headers", {})
|
||||
parameters = tool_config.get("parameters", {}) or {}
|
||||
values = tool_config.get("values", {})
|
||||
error_handling = tool_config.get("error_handling", {})
|
||||
|
||||
path_params = parameters.get("path_params") or {}
|
||||
query_params = parameters.get("query_params") or {}
|
||||
body_params = parameters.get("body_params") or {}
|
||||
|
||||
# Create function docstring with parameter documentation
|
||||
param_list = (
|
||||
list(path_params.keys())
|
||||
+ list(query_params.keys())
|
||||
+ list(body_params.keys())
|
||||
)
|
||||
doc_params = []
|
||||
for param in param_list:
|
||||
doc_params.append(f" {param}: Parameter description")
|
||||
|
||||
docstring = (
|
||||
f"{description}\n\nParameters:\n"
|
||||
+ "\n".join(doc_params)
|
||||
+ "\n\nReturns:\n String containing the response in JSON format"
|
||||
)
|
||||
|
||||
# Create the tool function using the decorator pattern in the documentation
|
||||
@tool(name=name)
|
||||
def http_tool(**kwargs):
|
||||
"""Tool function created dynamically."""
|
||||
try:
|
||||
# Combines default values with provided values
|
||||
all_values = {**values, **kwargs}
|
||||
|
||||
# Substitutes placeholders in headers
|
||||
processed_headers = {
|
||||
k: v.format(**all_values) if isinstance(v, str) else v
|
||||
for k, v in headers.items()
|
||||
}
|
||||
|
||||
# Processes path parameters
|
||||
url = endpoint
|
||||
for param, value in path_params.items():
|
||||
if param in all_values:
|
||||
url = url.replace(f"{{{param}}}", str(all_values[param]))
|
||||
|
||||
# Process query parameters
|
||||
query_params_dict = {}
|
||||
for param, value in query_params.items():
|
||||
if isinstance(value, list):
|
||||
# If the value is a list, join with comma
|
||||
query_params_dict[param] = ",".join(value)
|
||||
elif param in all_values:
|
||||
# If the parameter is in the values, use the provided value
|
||||
query_params_dict[param] = all_values[param]
|
||||
else:
|
||||
# Otherwise, use the default value from the configuration
|
||||
query_params_dict[param] = value
|
||||
|
||||
# Adds default values to query params if they are not present
|
||||
for param, value in values.items():
|
||||
if param not in query_params_dict and param not in path_params:
|
||||
query_params_dict[param] = value
|
||||
|
||||
body_data = {}
|
||||
for param, param_config in body_params.items():
|
||||
if param in all_values:
|
||||
body_data[param] = all_values[param]
|
||||
|
||||
# Adds default values to body if they are not present
|
||||
for param, value in values.items():
|
||||
if (
|
||||
param not in body_data
|
||||
and param not in query_params_dict
|
||||
and param not in path_params
|
||||
):
|
||||
body_data[param] = value
|
||||
|
||||
# Makes the HTTP request
|
||||
response = requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=processed_headers,
|
||||
params=query_params_dict,
|
||||
json=body_data or None,
|
||||
timeout=error_handling.get("timeout", 30),
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise requests.exceptions.HTTPError(
|
||||
f"Error in the request: {response.status_code} - {response.text}"
|
||||
)
|
||||
|
||||
# Always returns the response as a string
|
||||
return json.dumps(response.json())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {name}: {str(e)}")
|
||||
return json.dumps(
|
||||
error_handling.get(
|
||||
"fallback_response",
|
||||
{"error": "tool_execution_error", "message": str(e)},
|
||||
)
|
||||
)
|
||||
|
||||
# Replace the docstring
|
||||
http_tool.__doc__ = docstring
|
||||
|
||||
# Add cache function following the documentation
|
||||
def http_cache_function(arguments: dict, result: str) -> bool:
|
||||
"""Determines whether to cache the result based on arguments and result."""
|
||||
# Default implementation: cache all successful results
|
||||
try:
|
||||
# If the result is parseable JSON and not an error, cache it
|
||||
result_obj = json.loads(result)
|
||||
return not (isinstance(result_obj, dict) and "error" in result_obj)
|
||||
except Exception:
|
||||
# If result is not valid JSON, don't cache
|
||||
return False
|
||||
|
||||
# Assign the cache function to the tool
|
||||
http_tool.cache_function = http_cache_function
|
||||
|
||||
return http_tool
|
||||
|
||||
def build_tools(self, tools_config: Dict[str, Any]) -> List[BaseTool]:
|
||||
"""Builds a list of tools based on the provided configuration. Accepts both 'tools' and 'custom_tools' (with http_tools)."""
|
||||
self.tools = []
|
||||
|
||||
# Find HTTP tools configuration in various possible locations
|
||||
http_tools = []
|
||||
if tools_config.get("http_tools"):
|
||||
http_tools = tools_config.get("http_tools", [])
|
||||
elif tools_config.get("custom_tools") and tools_config["custom_tools"].get(
|
||||
"http_tools"
|
||||
):
|
||||
http_tools = tools_config["custom_tools"].get("http_tools", [])
|
||||
elif (
|
||||
tools_config.get("tools")
|
||||
and isinstance(tools_config["tools"], dict)
|
||||
and tools_config["tools"].get("http_tools")
|
||||
):
|
||||
http_tools = tools_config["tools"].get("http_tools", [])
|
||||
|
||||
# Determine which implementation method to use (BaseTool or decorator)
|
||||
use_decorator = tools_config.get("use_decorator", False)
|
||||
|
||||
# Create tools for each HTTP tool configuration
|
||||
for http_tool_config in http_tools:
|
||||
if use_decorator:
|
||||
self.tools.append(
|
||||
self._create_http_tool_with_decorator(http_tool_config)
|
||||
)
|
||||
else:
|
||||
self.tools.append(self._create_http_tool(http_tool_config))
|
||||
|
||||
return self.tools
|
264
src/services/crewai/mcp_service.py
Normal file
264
src/services/crewai/mcp_service.py
Normal file
@ -0,0 +1,264 @@
|
||||
"""
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ @author: Davidson Gomes │
|
||||
│ @file: mcp_service.py │
|
||||
│ Developed by: Davidson Gomes │
|
||||
│ Creation date: May 13, 2025 │
|
||||
│ Contact: contato@evolution-api.com │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @copyright © Evolution API 2025. All rights reserved. │
|
||||
│ Licensed under the Apache License, Version 2.0 │
|
||||
│ │
|
||||
│ You may not use this file except in compliance with the License. │
|
||||
│ You may obtain a copy of the License at │
|
||||
│ │
|
||||
│ http://www.apache.org/licenses/LICENSE-2.0 │
|
||||
│ │
|
||||
│ Unless required by applicable law or agreed to in writing, software │
|
||||
│ distributed under the License is distributed on an "AS IS" BASIS, │
|
||||
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
|
||||
│ See the License for the specific language governing permissions and │
|
||||
│ limitations under the License. │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @important │
|
||||
│ For any future changes to the code in this file, it is recommended to │
|
||||
│ include, together with the modification, the information of the developer │
|
||||
│ who changed it and the date of modification. │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from contextlib import ExitStack
|
||||
import os
|
||||
import sys
|
||||
from src.utils.logger import setup_logger
|
||||
from src.services.mcp_server_service import get_mcp_server
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from crewai_tools import MCPServerAdapter
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
HAS_MCP_PACKAGES = True
|
||||
except ImportError:
|
||||
logger = setup_logger(__name__)
|
||||
logger.error(
|
||||
"MCP packages are not installed. Please install mcp and crewai-tools[mcp]"
|
||||
)
|
||||
HAS_MCP_PACKAGES = False
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class MCPService:
|
||||
def __init__(self):
|
||||
self.tools = []
|
||||
self.exit_stack = ExitStack()
|
||||
|
||||
def _connect_to_mcp_server(
|
||||
self, server_config: Dict[str, Any]
|
||||
) -> Tuple[List[Any], Optional[ExitStack]]:
|
||||
"""Connect to a specific MCP server and return its tools."""
|
||||
if not HAS_MCP_PACKAGES:
|
||||
logger.error("Cannot connect to MCP server: MCP packages not installed")
|
||||
return [], None
|
||||
|
||||
try:
|
||||
# Determines the type of server (local or remote)
|
||||
if "url" in server_config:
|
||||
# Remote server (SSE) - Simplified approach using direct dictionary
|
||||
sse_config = {"url": server_config["url"]}
|
||||
|
||||
# Add headers if provided
|
||||
if "headers" in server_config and server_config["headers"]:
|
||||
sse_config["headers"] = server_config["headers"]
|
||||
|
||||
# Create the MCPServerAdapter with the SSE configuration
|
||||
mcp_adapter = MCPServerAdapter(sse_config)
|
||||
else:
|
||||
# Local server (Stdio)
|
||||
command = server_config.get("command", "npx")
|
||||
args = server_config.get("args", [])
|
||||
|
||||
# Adds environment variables if specified
|
||||
env = server_config.get("env", {})
|
||||
if env:
|
||||
for key, value in env.items():
|
||||
os.environ[key] = value
|
||||
|
||||
connection_params = StdioServerParameters(
|
||||
command=command, args=args, env=env
|
||||
)
|
||||
|
||||
# Create the MCPServerAdapter with the Stdio connection parameters
|
||||
mcp_adapter = MCPServerAdapter(connection_params)
|
||||
|
||||
# Get tools from the adapter
|
||||
tools = mcp_adapter.tools
|
||||
|
||||
# Return tools and the adapter (which serves as an exit stack)
|
||||
return tools, mcp_adapter
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to MCP server: {e}")
|
||||
return [], None
|
||||
|
||||
def _filter_incompatible_tools(self, tools: List[Any]) -> List[Any]:
|
||||
"""Filters incompatible tools with the model."""
|
||||
problematic_tools = [
|
||||
"create_pull_request_review", # This tool causes the 400 INVALID_ARGUMENT error
|
||||
]
|
||||
|
||||
filtered_tools = []
|
||||
removed_count = 0
|
||||
|
||||
for tool in tools:
|
||||
if tool.name in problematic_tools:
|
||||
logger.warning(f"Removing incompatible tool: {tool.name}")
|
||||
removed_count += 1
|
||||
else:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.warning(f"Removed {removed_count} incompatible tools.")
|
||||
|
||||
return filtered_tools
|
||||
|
||||
def _filter_tools_by_agent(
|
||||
self, tools: List[Any], agent_tools: List[str]
|
||||
) -> List[Any]:
|
||||
"""Filters tools compatible with the agent."""
|
||||
if not agent_tools:
|
||||
return tools
|
||||
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
logger.info(f"Tool: {tool.name}")
|
||||
if tool.name in agent_tools:
|
||||
filtered_tools.append(tool)
|
||||
return filtered_tools
|
||||
|
||||
async def build_tools(
|
||||
self, mcp_config: Dict[str, Any], db: Session
|
||||
) -> Tuple[List[Any], Any]:
|
||||
"""Builds a list of tools from multiple MCP servers."""
|
||||
if not HAS_MCP_PACKAGES:
|
||||
logger.error("Cannot build MCP tools: MCP packages not installed")
|
||||
return [], None
|
||||
|
||||
self.tools = []
|
||||
self.exit_stack = ExitStack()
|
||||
adapter_list = []
|
||||
|
||||
try:
|
||||
mcp_servers = mcp_config.get("mcp_servers", [])
|
||||
if mcp_servers is not None:
|
||||
# Process each MCP server in the configuration
|
||||
for server in mcp_servers:
|
||||
try:
|
||||
# Search for the MCP server in the database
|
||||
mcp_server = get_mcp_server(db, server["id"])
|
||||
if not mcp_server:
|
||||
logger.warning(f"MCP Server not found: {server['id']}")
|
||||
continue
|
||||
|
||||
# Prepares the server configuration
|
||||
server_config = mcp_server.config_json.copy()
|
||||
|
||||
# Replaces the environment variables in the config_json
|
||||
if "env" in server_config and server_config["env"] is not None:
|
||||
for key, value in server_config["env"].items():
|
||||
if value and value.startswith("env@@"):
|
||||
env_key = value.replace("env@@", "")
|
||||
if server.get("envs") and env_key in server.get(
|
||||
"envs", {}
|
||||
):
|
||||
server_config["env"][key] = server["envs"][
|
||||
env_key
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Connecting to MCP server: {mcp_server.name}")
|
||||
tools, adapter = self._connect_to_mcp_server(server_config)
|
||||
|
||||
if tools and adapter:
|
||||
# Filters incompatible tools
|
||||
filtered_tools = self._filter_incompatible_tools(tools)
|
||||
|
||||
# Filters tools compatible with the agent
|
||||
if agent_tools := server.get("tools", []):
|
||||
filtered_tools = self._filter_tools_by_agent(
|
||||
filtered_tools, agent_tools
|
||||
)
|
||||
self.tools.extend(filtered_tools)
|
||||
|
||||
# Add to the adapter list for cleanup later
|
||||
adapter_list.append(adapter)
|
||||
logger.info(
|
||||
f"MCP Server {mcp_server.name} connected successfully. Added {len(filtered_tools)} tools."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to connect or no tools available for {mcp_server.name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error connecting to MCP server {server.get('id', 'unknown')}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
custom_mcp_servers = mcp_config.get("custom_mcp_servers", [])
|
||||
if custom_mcp_servers is not None:
|
||||
# Process custom MCP servers
|
||||
for server in custom_mcp_servers:
|
||||
if not server:
|
||||
logger.warning(
|
||||
"Empty server configuration found in custom_mcp_servers"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Connecting to custom MCP server: {server.get('url', 'unknown')}"
|
||||
)
|
||||
tools, adapter = self._connect_to_mcp_server(server)
|
||||
|
||||
if tools:
|
||||
self.tools.extend(tools)
|
||||
else:
|
||||
logger.warning("No tools returned from custom MCP server")
|
||||
continue
|
||||
|
||||
if adapter:
|
||||
adapter_list.append(adapter)
|
||||
logger.info(
|
||||
f"Custom MCP server connected successfully. Added {len(tools)} tools."
|
||||
)
|
||||
else:
|
||||
logger.warning("No adapter returned from custom MCP server")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error connecting to custom MCP server {server.get('url', 'unknown')}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Ensure cleanup
|
||||
for adapter in adapter_list:
|
||||
if hasattr(adapter, "close"):
|
||||
adapter.close()
|
||||
logger.error(f"Fatal error connecting to MCP servers: {e}")
|
||||
# Return empty lists in case of error
|
||||
return [], None
|
||||
|
||||
# Return the tools and the adapter list for cleanup
|
||||
return self.tools, adapter_list
|
637
src/services/crewai/session_service.py
Normal file
637
src/services/crewai/session_service.py
Normal file
@ -0,0 +1,637 @@
|
||||
"""
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ @author: Davidson Gomes │
|
||||
│ @file: session_service.py │
|
||||
│ Developed by: Davidson Gomes │
|
||||
│ Creation date: May 13, 2025 │
|
||||
│ Contact: contato@evolution-api.com │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @copyright © Evolution API 2025. All rights reserved. │
|
||||
│ Licensed under the Apache License, Version 2.0 │
|
||||
│ │
|
||||
│ You may not use this file except in compliance with the License. │
|
||||
│ You may obtain a copy of the License at │
|
||||
│ │
|
||||
│ http://www.apache.org/licenses/LICENSE-2.0 │
|
||||
│ │
|
||||
│ Unless required by applicable law or agreed to in writing, software │
|
||||
│ distributed under the License is distributed on an "AS IS" BASIS, │
|
||||
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
|
||||
│ See the License for the specific language governing permissions and │
|
||||
│ limitations under the License. │
|
||||
├──────────────────────────────────────────────────────────────────────────────┤
|
||||
│ @important │
|
||||
│ For any future changes to the code in this file, it is recommended to │
|
||||
│ include, together with the modification, the information of the developer │
|
||||
│ who changed it and the date of modification. │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional, Union, Set
|
||||
|
||||
from sqlalchemy import create_engine, Boolean, Text, ForeignKeyConstraint
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.orm import (
|
||||
sessionmaker,
|
||||
relationship,
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
mapped_column,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.types import DateTime, PickleType, String
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class DynamicJSON(TypeDecorator):
|
||||
"""JSON type compatible with ADK that uses JSONB in PostgreSQL and TEXT with JSON
|
||||
serialization for other databases."""
|
||||
|
||||
impl = Text # Default implementation is TEXT
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(postgresql.JSONB)
|
||||
else:
|
||||
return dialect.type_descriptor(Text)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value
|
||||
else:
|
||||
return json.dumps(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value
|
||||
else:
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for database tables."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StorageSession(Base):
|
||||
"""Represents a session stored in the database, compatible with ADK."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
|
||||
create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
storage_events: Mapped[list["StorageEvent"]] = relationship(
|
||||
"StorageEvent",
|
||||
back_populates="storage_session",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
|
||||
|
||||
|
||||
class StorageEvent(Base):
|
||||
"""Represents an event stored in the database, compatible with ADK."""
|
||||
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
session_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
invocation_id: Mapped[str] = mapped_column(String)
|
||||
author: Mapped[str] = mapped_column(String)
|
||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
||||
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
||||
|
||||
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
error_code: Mapped[str] = mapped_column(String, nullable=True)
|
||||
error_message: Mapped[str] = mapped_column(String, nullable=True)
|
||||
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
storage_session: Mapped[StorageSession] = relationship(
|
||||
"StorageSession",
|
||||
back_populates="storage_events",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["app_name", "user_id", "session_id"],
|
||||
["sessions.app_name", "sessions.user_id", "sessions.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def long_running_tool_ids(self) -> set[str]:
|
||||
return (
|
||||
set(json.loads(self.long_running_tool_ids_json))
|
||||
if self.long_running_tool_ids_json
|
||||
else set()
|
||||
)
|
||||
|
||||
@long_running_tool_ids.setter
|
||||
def long_running_tool_ids(self, value: set[str]):
|
||||
if value is None:
|
||||
self.long_running_tool_ids_json = None
|
||||
else:
|
||||
self.long_running_tool_ids_json = json.dumps(list(value))
|
||||
|
||||
|
||||
class StorageAppState(Base):
|
||||
"""Represents an application state stored in the database, compatible with ADK."""
|
||||
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class StorageUserState(Base):
|
||||
"""Represents a user state stored in the database, compatible with ADK."""
|
||||
|
||||
__tablename__ = "user_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
# Pydantic model classes compatible with ADK
|
||||
class State:
|
||||
"""Utility class for states, compatible with ADK."""
|
||||
|
||||
APP_PREFIX = "app:"
|
||||
USER_PREFIX = "user:"
|
||||
TEMP_PREFIX = "temp:"
|
||||
|
||||
|
||||
class Content(BaseModel):
|
||||
"""Event content model, compatible with ADK."""
|
||||
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Part(BaseModel):
|
||||
"""Content part model, compatible with ADK."""
|
||||
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""Event model, compatible with ADK."""
|
||||
|
||||
id: Optional[str] = None
|
||||
author: str
|
||||
branch: Optional[str] = None
|
||||
invocation_id: Optional[str] = None
|
||||
content: Optional[Content] = None
|
||||
actions: Optional[Dict[str, Any]] = None
|
||||
timestamp: Optional[float] = None
|
||||
long_running_tool_ids: Optional[Set[str]] = None
|
||||
grounding_metadata: Optional[Dict[str, Any]] = None
|
||||
partial: Optional[bool] = None
|
||||
turn_complete: Optional[bool] = None
|
||||
error_code: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
interrupted: Optional[bool] = None
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""Session model, compatible with ADK."""
|
||||
|
||||
app_name: str
|
||||
user_id: str
|
||||
id: str
|
||||
state: Dict[str, Any] = {}
|
||||
events: List[Event] = []
|
||||
last_update_time: float
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class CrewSessionService:
|
||||
"""Service for managing CrewAI agent sessions using ADK tables."""
|
||||
|
||||
def __init__(self, db_url: str):
|
||||
"""
|
||||
Initializes the session service.
|
||||
|
||||
Args:
|
||||
db_url: Database connection URL.
|
||||
"""
|
||||
try:
|
||||
self.engine = create_engine(db_url)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create database engine: {e}")
|
||||
|
||||
# Create all tables
|
||||
Base.metadata.create_all(self.engine)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
logger.info(f"CrewSessionService started with database at {db_url}")
|
||||
|
||||
def create_session(
|
||||
self, agent_id: str, external_id: str, session_id: Optional[str] = None
|
||||
) -> Session:
|
||||
"""
|
||||
Creates a new session for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID (used as app_name in ADK)
|
||||
external_id: External ID (used as user_id in ADK)
|
||||
session_id: Optional session ID
|
||||
|
||||
Returns:
|
||||
Session: The created session
|
||||
"""
|
||||
session_id = session_id or str(uuid.uuid4())
|
||||
|
||||
with self.Session() as db_session:
|
||||
# Check if app and user states already exist
|
||||
storage_app_state = db_session.get(StorageAppState, (agent_id))
|
||||
storage_user_state = db_session.get(
|
||||
StorageUserState, (agent_id, external_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
|
||||
# Create states if they don't exist
|
||||
if not storage_app_state:
|
||||
storage_app_state = StorageAppState(app_name=agent_id, state={})
|
||||
db_session.add(storage_app_state)
|
||||
|
||||
if not storage_user_state:
|
||||
storage_user_state = StorageUserState(
|
||||
app_name=agent_id, user_id=external_id, state={}
|
||||
)
|
||||
db_session.add(storage_user_state)
|
||||
|
||||
# Create session
|
||||
storage_session = StorageSession(
|
||||
app_name=agent_id,
|
||||
user_id=external_id,
|
||||
id=session_id,
|
||||
state={},
|
||||
)
|
||||
db_session.add(storage_session)
|
||||
db_session.commit()
|
||||
|
||||
# Get timestamp
|
||||
db_session.refresh(storage_session)
|
||||
|
||||
# Merge states for response
|
||||
merged_state = _merge_state(app_state, user_state, {})
|
||||
|
||||
# Create Session object for return
|
||||
session = Session(
|
||||
app_name=agent_id,
|
||||
user_id=external_id,
|
||||
id=session_id,
|
||||
state=merged_state,
|
||||
last_update_time=storage_session.update_time.timestamp(),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Session created: {session_id} for agent {agent_id} and user {external_id}"
|
||||
)
|
||||
return session
|
||||
|
||||
def get_session(
|
||||
self, agent_id: str, external_id: str, session_id: str
|
||||
) -> Optional[Session]:
|
||||
"""
|
||||
Retrieves a session from the database.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
external_id: User ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The retrieved session or None if not found
|
||||
"""
|
||||
with self.Session() as db_session:
|
||||
storage_session = db_session.get(
|
||||
StorageSession, (agent_id, external_id, session_id)
|
||||
)
|
||||
|
||||
if storage_session is None:
|
||||
return None
|
||||
|
||||
# Fetch session events
|
||||
storage_events = (
|
||||
db_session.query(StorageEvent)
|
||||
.filter(StorageEvent.session_id == storage_session.id)
|
||||
.filter(StorageEvent.app_name == agent_id)
|
||||
.filter(StorageEvent.user_id == external_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Fetch states
|
||||
storage_app_state = db_session.get(StorageAppState, (agent_id))
|
||||
storage_user_state = db_session.get(
|
||||
StorageUserState, (agent_id, external_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
session_state = storage_session.state
|
||||
|
||||
# Merge states
|
||||
merged_state = _merge_state(app_state, user_state, session_state)
|
||||
|
||||
# Create session
|
||||
session = Session(
|
||||
app_name=agent_id,
|
||||
user_id=external_id,
|
||||
id=session_id,
|
||||
state=merged_state,
|
||||
last_update_time=storage_session.update_time.timestamp(),
|
||||
)
|
||||
|
||||
# Add events
|
||||
session.events = [
|
||||
Event(
|
||||
id=e.id,
|
||||
author=e.author,
|
||||
branch=e.branch,
|
||||
invocation_id=e.invocation_id,
|
||||
content=_decode_content(e.content),
|
||||
actions=e.actions,
|
||||
timestamp=e.timestamp.timestamp(),
|
||||
long_running_tool_ids=e.long_running_tool_ids,
|
||||
grounding_metadata=e.grounding_metadata,
|
||||
partial=e.partial,
|
||||
turn_complete=e.turn_complete,
|
||||
error_code=e.error_code,
|
||||
error_message=e.error_message,
|
||||
interrupted=e.interrupted,
|
||||
)
|
||||
for e in storage_events
|
||||
]
|
||||
|
||||
return session
|
||||
|
||||
def save_session(self, session: Session) -> None:
|
||||
"""
|
||||
Saves a session to the database.
|
||||
|
||||
Args:
|
||||
session: The session to save
|
||||
"""
|
||||
with self.Session() as db_session:
|
||||
storage_session = db_session.get(
|
||||
StorageSession, (session.app_name, session.user_id, session.id)
|
||||
)
|
||||
|
||||
if not storage_session:
|
||||
logger.error(f"Session not found: {session.id}")
|
||||
return
|
||||
|
||||
# Check states
|
||||
storage_app_state = db_session.get(StorageAppState, (session.app_name))
|
||||
storage_user_state = db_session.get(
|
||||
StorageUserState, (session.app_name, session.user_id)
|
||||
)
|
||||
|
||||
# Extract state deltas
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
|
||||
# Apply state deltas
|
||||
if storage_app_state and app_state_delta:
|
||||
storage_app_state.state.update(app_state_delta)
|
||||
|
||||
if storage_user_state and user_state_delta:
|
||||
storage_user_state.state.update(user_state_delta)
|
||||
|
||||
storage_session.state.update(session_state_delta)
|
||||
|
||||
# Save new events
|
||||
for event in session.events:
|
||||
# Check if event already exists
|
||||
existing_event = (
|
||||
(
|
||||
db_session.query(StorageEvent)
|
||||
.filter(StorageEvent.id == event.id)
|
||||
.filter(StorageEvent.app_name == session.app_name)
|
||||
.filter(StorageEvent.user_id == session.user_id)
|
||||
.filter(StorageEvent.session_id == session.id)
|
||||
.first()
|
||||
)
|
||||
if event.id
|
||||
else None
|
||||
)
|
||||
|
||||
if existing_event:
|
||||
continue
|
||||
|
||||
# Generate ID for the event if it doesn't exist
|
||||
if not event.id:
|
||||
event.id = str(uuid.uuid4())
|
||||
|
||||
# Create timestamp if it doesn't exist
|
||||
if not event.timestamp:
|
||||
event.timestamp = datetime.now().timestamp()
|
||||
|
||||
# Create StorageEvent object
|
||||
storage_event = StorageEvent(
|
||||
id=event.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
invocation_id=event.invocation_id or str(uuid.uuid4()),
|
||||
author=event.author,
|
||||
branch=event.branch,
|
||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||
actions=event.actions or {},
|
||||
long_running_tool_ids=event.long_running_tool_ids or set(),
|
||||
grounding_metadata=event.grounding_metadata,
|
||||
partial=event.partial,
|
||||
turn_complete=event.turn_complete,
|
||||
error_code=event.error_code,
|
||||
error_message=event.error_message,
|
||||
interrupted=event.interrupted,
|
||||
)
|
||||
|
||||
# Encode content, if it exists
|
||||
if event.content:
|
||||
encoded_content = event.content.model_dump(exclude_none=True)
|
||||
# Solution for serialization issues with multimedia content
|
||||
for p in encoded_content.get("parts", []):
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = (
|
||||
base64.b64encode(p["inline_data"]["data"]).decode(
|
||||
"utf-8"
|
||||
),
|
||||
)
|
||||
storage_event.content = encoded_content
|
||||
|
||||
db_session.add(storage_event)
|
||||
|
||||
# Commit changes
|
||||
db_session.commit()
|
||||
|
||||
# Update timestamp in session
|
||||
db_session.refresh(storage_session)
|
||||
session.last_update_time = storage_session.update_time.timestamp()
|
||||
|
||||
logger.info(f"Session saved: {session.id} with {len(session.events)} events")
|
||||
|
||||
def list_sessions(self, agent_id: str, external_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Lists all sessions for an agent and user.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
external_id: User ID
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of summarized sessions
|
||||
"""
|
||||
with self.Session() as db_session:
|
||||
sessions = (
|
||||
db_session.query(StorageSession)
|
||||
.filter(StorageSession.app_name == agent_id)
|
||||
.filter(StorageSession.user_id == external_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
result = []
|
||||
for session in sessions:
|
||||
result.append(
|
||||
{
|
||||
"app_name": session.app_name,
|
||||
"user_id": session.user_id,
|
||||
"id": session.id,
|
||||
"created_at": session.create_time.isoformat(),
|
||||
"updated_at": session.update_time.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def delete_session(self, agent_id: str, external_id: str, session_id: str) -> bool:
|
||||
"""
|
||||
Deletes a session from the database.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
external_id: User ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
bool: True if the session was deleted, False otherwise
|
||||
"""
|
||||
from sqlalchemy import delete
|
||||
|
||||
with self.Session() as db_session:
|
||||
stmt = delete(StorageSession).where(
|
||||
StorageSession.app_name == agent_id,
|
||||
StorageSession.user_id == external_id,
|
||||
StorageSession.id == session_id,
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
logger.info(f"Session deleted: {session_id}")
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
# Utility functions compatible with ADK
|
||||
|
||||
|
||||
def _extract_state_delta(state: dict[str, Any]):
|
||||
"""Extracts state deltas between app, user, and session."""
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
|
||||
if state:
|
||||
for key in state.keys():
|
||||
if key.startswith(State.APP_PREFIX):
|
||||
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
|
||||
elif key.startswith(State.USER_PREFIX):
|
||||
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
|
||||
elif not key.startswith(State.TEMP_PREFIX):
|
||||
session_state_delta[key] = state[key]
|
||||
|
||||
return app_state_delta, user_state_delta, session_state_delta
|
||||
|
||||
|
||||
def _merge_state(app_state, user_state, session_state):
|
||||
"""Merges app, user, and session states into a single object."""
|
||||
merged_state = copy.deepcopy(session_state)
|
||||
|
||||
for key in app_state.keys():
|
||||
merged_state[State.APP_PREFIX + key] = app_state[key]
|
||||
|
||||
for key in user_state.keys():
|
||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||
|
||||
return merged_state
|
||||
|
||||
|
||||
def _decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]:
|
||||
"""Decodes event content potentially with binary data."""
|
||||
if not content:
|
||||
return None
|
||||
|
||||
for p in content.get("parts", []):
|
||||
if "inline_data" in p and isinstance(p["inline_data"].get("data"), tuple):
|
||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
||||
|
||||
return Content.model_validate(content)
|
@ -27,12 +27,22 @@
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from src.config.settings import settings
|
||||
import os
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from google.adk.memory import InMemoryMemoryService
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from src.services.crewai.session_service import CrewSessionService
|
||||
|
||||
if os.getenv("AI_ENGINE") == "crewai":
|
||||
session_service = CrewSessionService(db_url=os.getenv("POSTGRES_CONNECTION_STRING"))
|
||||
else:
|
||||
session_service = DatabaseSessionService(
|
||||
db_url=os.getenv("POSTGRES_CONNECTION_STRING")
|
||||
)
|
||||
|
||||
# Initialize service instances
|
||||
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
|
||||
artifacts_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
Loading…
Reference in New Issue
Block a user