feat(api): integrate new AI engines and update chat routes for dynamic agent handling

parent 9f176bf0e0
commit cf24a7ce5d
@@ -51,6 +51,9 @@ dependencies = [
     "langgraph==0.4.1",
     "opentelemetry-sdk==1.33.0",
     "opentelemetry-exporter-otlp==1.33.0",
+    "mcp==1.9.0",
+    "crewai==0.120.1",
+    "crewai-tools==0.45.0",
 ]

 [project.optional-dependencies]
@@ -39,6 +39,7 @@ from fastapi import (
     Header,
 )
 from sqlalchemy.orm import Session
+from src.config.settings import settings
 from src.config.database import get_db
 from src.core.jwt_middleware import (
     get_jwt_token,
@@ -49,7 +50,8 @@ from src.services import (
     agent_service,
 )
 from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
-from src.services.agent_runner import run_agent, run_agent_stream
+from src.services.adk.agent_runner import run_agent as run_agent_adk, run_agent_stream
+from src.services.crewai.agent_runner import run_agent as run_agent_crewai
 from src.core.exceptions import AgentNotFoundError
 from src.services.service_providers import (
     session_service,
@@ -262,7 +264,7 @@ async def websocket_chat(


 @router.post(
-    "",
+    "/{agent_id}/{external_id}",
     response_model=ChatResponse,
     responses={
         400: {"model": ErrorResponse},
@@ -272,20 +274,32 @@ async def websocket_chat(
 )
 async def chat(
     request: ChatRequest,
+    agent_id: str,
+    external_id: str,
     _=Depends(get_agent_by_api_key),
     db: Session = Depends(get_db),
 ):
     try:
-        final_response = await run_agent(
-            request.agent_id,
-            request.external_id,
-            request.message,
-            session_service,
-            artifacts_service,
-            memory_service,
-            db,
-            files=request.files,
-        )
+        if settings.AI_ENGINE == "adk":
+            final_response = await run_agent_adk(
+                agent_id,
+                external_id,
+                request.message,
+                session_service,
+                artifacts_service,
+                memory_service,
+                db,
+                files=request.files,
+            )
+        elif settings.AI_ENGINE == "crewai":
+            final_response = await run_agent_crewai(
+                agent_id,
+                external_id,
+                request.message,
+                session_service,
+                db,
+                files=request.files,
+            )

         return {
             "response": final_response["final_response"],
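How the reworked route is called once this lands (a minimal sketch, assuming the
router is mounted under /api/v1/chat and that get_agent_by_api_key reads an
"x-api-key" header; both are assumptions, so check the actual prefix and auth
dependency):

    import httpx

    agent_id = "3fa85f64-..."   # hypothetical agent UUID
    external_id = "user-42"     # hypothetical external user id

    resp = httpx.post(
        f"http://localhost:8000/api/v1/chat/{agent_id}/{external_id}",
        headers={"x-api-key": "my-api-key"},  # hypothetical key
        json={"message": "Hello!"},  # ids now travel in the path, not the body
    )
    print(resp.json()["response"])

Which engine handles the request is decided server-side by settings.AI_ENGINE,
so the same call works whether the agent runs on ADK or CrewAI.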
New file (3 lines):

from src.config.settings import settings

__all__ = ["settings"]
@@ -57,6 +57,9 @@ class Settings(BaseSettings):
         "POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
     )

+    # AI engine settings
+    AI_ENGINE: str = os.getenv("AI_ENGINE", "adk")
+
     # Logging settings
     LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
     LOG_DIR: str = "logs"
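A sketch of how downstream code can guard the new flag (the guard itself is an
assumption for illustration; the commit only reads settings.AI_ENGINE directly):

    from src.config.settings import settings

    SUPPORTED_ENGINES = {"adk", "crewai"}  # hypothetical constant

    if settings.AI_ENGINE not in SUPPORTED_ENGINES:
        raise ValueError(f"Unsupported AI_ENGINE: {settings.AI_ENGINE}")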
@@ -83,11 +86,11 @@ class Settings(BaseSettings):

     # Email provider settings
     EMAIL_PROVIDER: str = os.getenv("EMAIL_PROVIDER", "sendgrid")

     # SendGrid settings
     SENDGRID_API_KEY: str = os.getenv("SENDGRID_API_KEY", "")
     EMAIL_FROM: str = os.getenv("EMAIL_FROM", "noreply@yourdomain.com")

     # SMTP settings
     SMTP_HOST: str = os.getenv("SMTP_HOST", "")
     SMTP_PORT: int = int(os.getenv("SMTP_PORT", 587))
@@ -96,7 +99,7 @@ class Settings(BaseSettings):
     SMTP_USE_TLS: bool = os.getenv("SMTP_USE_TLS", "true").lower() == "true"
     SMTP_USE_SSL: bool = os.getenv("SMTP_USE_SSL", "false").lower() == "true"
     SMTP_FROM: str = os.getenv("SMTP_FROM", "")

     APP_URL: str = os.getenv("APP_URL", "http://localhost:8000")

     # Server settings
@@ -43,9 +43,11 @@ class FileData(BaseModel):
 class ChatRequest(BaseModel):
     """Model to represent a chat request."""

-    agent_id: str = Field(..., description="Agent ID to process the message")
-    external_id: str = Field(..., description="External ID for user identification")
     message: str = Field(..., description="User message to the agent")
+    agent_id: Optional[str] = Field(None, description="Agent ID to process the message")
+    external_id: Optional[str] = Field(
+        None, description="External ID for user identification"
+    )
     files: Optional[List[FileData]] = Field(
         None, description="List of files attached to the message"
     )
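With the identifiers now optional (they normally arrive via the new path
parameters), the request body can shrink to a single message. A sketch of the
minimal and the explicit payloads (values illustrative):

    minimal = {"message": "Summarize yesterday's tickets"}

    explicit = {
        "message": "Summarize yesterday's tickets",
        "agent_id": "3fa85f64-...",   # hypothetical UUID
        "external_id": "user-42",
        "files": None,
    }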
@@ -1 +1 @@
-from .agent_runner import run_agent
+from .adk.agent_runner import run_agent
@@ -45,7 +45,7 @@ from src.services.agent_service import (
 )
 from src.services.mcp_server_service import get_mcp_server

-from src.services.agent_runner import run_agent, run_agent_stream
+from src.services.adk.agent_runner import run_agent, run_agent_stream
 from src.services.service_providers import (
     session_service,
     artifacts_service,
@@ -388,7 +388,6 @@ class A2ATaskManager:
         self, request: SendTaskStreamingRequest, agent: Agent
     ) -> AsyncIterable[SendTaskStreamingResponse]:
         """Processes a task in streaming mode using the specified agent."""
-        # Extrair e processar arquivos da mesma forma que no método _process_task
         query = self._extract_user_query(request.params)

         try:
@@ -448,21 +447,19 @@ class A2ATaskManager:
                 ),
             )

-            # Use os arquivos processados do _extract_user_query
             files = getattr(self, "_last_processed_files", None)

-            # Log sobre os arquivos processados
             if files:
                 logger.info(
-                    f"Streaming: Passando {len(files)} arquivos processados para run_agent_stream"
+                    f"Streaming: Uploading {len(files)} files to run_agent_stream"
                 )
                 for file_info in files:
                     logger.info(
-                        f"Streaming: Arquivo sendo enviado: {file_info.filename} ({file_info.content_type})"
+                        f"Streaming: File being sent: {file_info.filename} ({file_info.content_type})"
                     )
             else:
                 logger.warning(
-                    "Streaming: Nenhum arquivo processado disponível para enviar ao agente"
+                    "Streaming: No processed files available to send to the agent"
                 )

             async for chunk in run_agent_stream(
@@ -473,7 +470,7 @@ class A2ATaskManager:
                 artifacts_service=artifacts_service,
                 memory_service=memory_service,
                 db=self.db,
-                files=files,  # Passar os arquivos processados para o streaming
+                files=files,
             ):
                 try:
                     chunk_data = json.loads(chunk)
@@ -36,11 +36,11 @@ from src.schemas.schemas import Agent
 from src.utils.logger import setup_logger
 from src.core.exceptions import AgentNotFoundError
 from src.services.agent_service import get_agent
-from src.services.custom_tools import CustomToolBuilder
-from src.services.mcp_service import MCPService
-from src.services.custom_agents.a2a_agent import A2ACustomAgent
-from src.services.custom_agents.workflow_agent import WorkflowAgent
-from src.services.custom_agents.task_agent import TaskAgent
+from src.services.adk.custom_tools import CustomToolBuilder
+from src.services.adk.mcp_service import MCPService
+from src.services.adk.custom_agents.a2a_agent import A2ACustomAgent
+from src.services.adk.custom_agents.workflow_agent import WorkflowAgent
+from src.services.adk.custom_agents.task_agent import TaskAgent
 from src.services.apikey_service import get_decrypted_api_key
 from sqlalchemy.orm import Session
 from contextlib import AsyncExitStack
@@ -35,7 +35,7 @@ from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
 from src.utils.logger import setup_logger
 from src.core.exceptions import AgentNotFoundError, InternalServerError
 from src.services.agent_service import get_agent
-from src.services.agent_builder import AgentBuilder
+from src.services.adk.agent_builder import AgentBuilder
 from sqlalchemy.orm import Session
 from typing import Optional, AsyncGenerator
 import asyncio
src/services/adk/custom_agents/__init__.py (new file, 0 lines)
@@ -162,7 +162,7 @@ class TaskAgent(BaseAgent):
             ),
         )

-        from src.services.agent_builder import AgentBuilder
+        from src.services.adk.agent_builder import AgentBuilder

         print(f"Building agent in Task agent: {agent.name}")
         agent_builder = AgentBuilder(self.db)
@@ -181,7 +181,7 @@ class WorkflowAgent(BaseAgent):
             return

         # Import moved to inside the function to avoid circular import
-        from src.services.agent_builder import AgentBuilder
+        from src.services.adk.agent_builder import AgentBuilder

         agent_builder = AgentBuilder(self.db)
         root_agent, exit_stack = await agent_builder.build_agent(agent)
src/services/crewai/agent_builder.py (new file, 219 lines)

"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes                                                      │
│ @file: agent_builder.py                                                      │
│ Developed by: Davidson Gomes                                                 │
│ Creation date: May 13, 2025                                                  │
│ Contact: contato@evolution-api.com                                           │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved.                        │
│ Licensed under the Apache License, Version 2.0                               │
│                                                                              │
│ You may not use this file except in compliance with the License.             │
│ You may obtain a copy of the License at                                      │
│                                                                              │
│    http://www.apache.org/licenses/LICENSE-2.0                                │
│                                                                              │
│ Unless required by applicable law or agreed to in writing, software          │
│ distributed under the License is distributed on an "AS IS" BASIS,            │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     │
│ See the License for the specific language governing permissions and          │
│ limitations under the License.                                               │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important                                                                   │
│ For any future changes to the code in this file, it is recommended to        │
│ include, together with the modification, the information of the developer    │
│ who changed it and the date of modification.                                 │
└──────────────────────────────────────────────────────────────────────────────┘
"""

from typing import List, Tuple, Optional
from src.schemas.schemas import Agent
from src.schemas.agent_config import AgentTask
from src.services.crewai.custom_tool import CustomToolBuilder
from src.services.crewai.mcp_service import MCPService
from src.utils.logger import setup_logger
from src.services.apikey_service import get_decrypted_api_key
from sqlalchemy.orm import Session
from contextlib import AsyncExitStack
from crewai import LLM, Agent as LlmAgent, Crew, Task, Process

from datetime import datetime
import uuid

logger = setup_logger(__name__)


class AgentBuilder:
    def __init__(self, db: Session):
        self.db = db
        self.custom_tool_builder = CustomToolBuilder()
        self.mcp_service = MCPService()

    async def _get_api_key(self, agent: Agent) -> str:
        """Get the API key for the agent."""
        api_key = None

        # Get API key from api_key_id
        if hasattr(agent, "api_key_id") and agent.api_key_id:
            if decrypted_key := get_decrypted_api_key(self.db, agent.api_key_id):
                logger.info(f"Using stored API key for agent {agent.name}")
                api_key = decrypted_key
            else:
                logger.error(f"Stored API key not found for agent {agent.name}")
                raise ValueError(
                    f"API key with ID {agent.api_key_id} not found or inactive"
                )
        else:
            # Check if there is an API key in the config (temporary field)
            config_api_key = agent.config.get("api_key") if agent.config else None
            if config_api_key:
                logger.info(f"Using config API key for agent {agent.name}")
                # Check if it is a UUID of a stored key
                try:
                    key_id = uuid.UUID(config_api_key)
                    if decrypted_key := get_decrypted_api_key(self.db, key_id):
                        logger.info("Config API key is a valid reference")
                        api_key = decrypted_key
                    else:
                        # Use the key directly
                        api_key = config_api_key
                except (ValueError, TypeError):
                    # It is not a UUID, use directly
                    api_key = config_api_key
            else:
                logger.error(f"No API key configured for agent {agent.name}")
                raise ValueError(
                    f"Agent {agent.name} does not have a configured API key"
                )

        return api_key

    async def _create_llm(self, agent: Agent) -> LLM:
        """Create an LLM from the agent data."""
        api_key = await self._get_api_key(agent)

        return LLM(model=agent.model, api_key=api_key)

    async def _create_llm_agent(
        self, agent: Agent, enabled_tools: List[str] = []
    ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
        """Create an LLM agent from the agent data."""
        # Get custom tools from the configuration
        custom_tools = self.custom_tool_builder.build_tools(agent.config)

        # Get MCP tools from the configuration
        mcp_tools = []
        mcp_exit_stack = None
        if agent.config.get("mcp_servers") or agent.config.get("custom_mcp_servers"):
            try:
                mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
                    agent.config, self.db
                )
            except Exception as e:
                logger.error(f"Error building MCP tools: {e}")
                # Continue without MCP tools
                mcp_tools = []
                mcp_exit_stack = None

        # # Get agent tools
        # agent_tools = await self._agent_tools_builder(agent)

        # Combine all tools
        all_tools = custom_tools + mcp_tools

        if enabled_tools:
            all_tools = [tool for tool in all_tools if tool.name in enabled_tools]
            logger.info(f"Enabled tools filter applied. Total tools: {len(all_tools)}")

        now = datetime.now()
        current_datetime = now.strftime("%d/%m/%Y %H:%M")
        current_day_of_week = now.strftime("%A")
        current_date_iso = now.strftime("%Y-%m-%d")
        current_time = now.strftime("%H:%M")

        # Substitute variables in the prompt
        formatted_prompt = agent.instruction.format(
            current_datetime=current_datetime,
            current_day_of_week=current_day_of_week,
            current_date_iso=current_date_iso,
            current_time=current_time,
        )

        llm_agent = LlmAgent(
            role=agent.role,
            goal=agent.goal,
            backstory=formatted_prompt,
            llm=await self._create_llm(agent),
            tools=all_tools,
            verbose=True,
            cache=True,
            # memory=True,
        )

        return llm_agent, mcp_exit_stack

    async def _create_tasks(
        self, agent: LlmAgent, tasks: List[AgentTask] = []
    ) -> List[Task]:
        """Create tasks from the agent data."""
        tasks_list = []
        if tasks:
            tasks_list.extend(
                Task(
                    name=task.name,
                    description=task.description,
                    expected_output=task.expected_output,
                    agent=agent,
                    verbose=True,
                )
                for task in tasks
            )
        return tasks_list

    async def build_crew(self, agents: List[LlmAgent], tasks: List[Task] = []) -> Crew:
        """Create a crew from the agent data."""
        return Crew(
            agents=agents,
            tasks=tasks,
            process=Process.sequential,
            verbose=True,
        )

    async def build_llm_agent(
        self, root_agent, enabled_tools: List[str] = []
    ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
        """Build an LLM agent with its sub-agents."""
        logger.info("Creating LLM agent")

        try:
            result = await self._create_llm_agent(root_agent, enabled_tools)

            if isinstance(result, tuple) and len(result) == 2:
                return result
            else:
                return result, None
        except Exception as e:
            logger.error(f"Error in build_llm_agent: {e}")
            raise

    async def build_agent(
        self, root_agent, enabled_tools: List[str] = []
    ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
        """Build the appropriate agent based on the type of the root agent."""
        if root_agent.type == "llm":
            agent, exit_stack = await self.build_llm_agent(root_agent, enabled_tools)
            return agent, exit_stack
        elif root_agent.type == "a2a":
            raise ValueError("A2A agents are not supported yet")
            # return await self.build_a2a_agent(root_agent)
        elif root_agent.type == "workflow":
            raise ValueError("Workflow agents are not supported yet")
            # return await self.build_workflow_agent(root_agent)
        elif root_agent.type == "task":
            raise ValueError("Task agents are not supported yet")
            # return await self.build_task_agent(root_agent)
        else:
            raise ValueError(f"Invalid agent type: {root_agent.type}")
            # return await self.build_composite_agent(root_agent)
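A short usage sketch of the new builder (illustrative; stored_agent is assumed
to come from agent_service.get_agent, and only type == "llm" is supported so
far):

    agent_builder = AgentBuilder(db)
    root_agent, exit_stack = await agent_builder.build_agent(stored_agent)
    crew = await agent_builder.build_crew([root_agent], tasks=[])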
src/services/crewai/agent_runner.py (new file, 595 lines)

"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes                                                      │
│ @file: agent_runner.py                                                       │
│ Developed by: Davidson Gomes                                                 │
│ Creation date: May 13, 2025                                                  │
│ Contact: contato@evolution-api.com                                           │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved.                        │
│ Licensed under the Apache License, Version 2.0                               │
│                                                                              │
│ You may not use this file except in compliance with the License.             │
│ You may obtain a copy of the License at                                      │
│                                                                              │
│    http://www.apache.org/licenses/LICENSE-2.0                                │
│                                                                              │
│ Unless required by applicable law or agreed to in writing, software          │
│ distributed under the License is distributed on an "AS IS" BASIS,            │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     │
│ See the License for the specific language governing permissions and          │
│ limitations under the License.                                               │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important                                                                   │
│ For any future changes to the code in this file, it is recommended to        │
│ include, together with the modification, the information of the developer    │
│ who changed it and the date of modification.                                 │
└──────────────────────────────────────────────────────────────────────────────┘
"""

from crewai import Crew, Task, Agent as LlmAgent
from src.services.crewai.session_service import (
    CrewSessionService,
    Event,
    Content,
    Part,
    Session,
)
from src.services.crewai.agent_builder import AgentBuilder
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError, InternalServerError
from src.services.agent_service import get_agent
from sqlalchemy.orm import Session
from typing import Optional, AsyncGenerator
import asyncio
import json
from datetime import datetime
from src.utils.otel import get_tracer
from opentelemetry import trace
import base64

logger = setup_logger(__name__)


def extract_text_from_output(crew_output):
    """Extract text from CrewOutput object."""
    if hasattr(crew_output, "raw") and crew_output.raw:
        return crew_output.raw
    elif hasattr(crew_output, "__str__"):
        return str(crew_output)

    # Fallback if no text found
    return "Unable to extract a valid response."


async def run_agent(
    agent_id: str,
    external_id: str,
    message: str,
    session_service: CrewSessionService,
    db: Session,
    session_id: Optional[str] = None,
    timeout: float = 60.0,
    files: Optional[list] = None,
):
    tracer = get_tracer()
    with tracer.start_as_current_span(
        "run_agent",
        attributes={
            "agent_id": agent_id,
            "external_id": external_id,
            "session_id": session_id or f"{external_id}_{agent_id}",
            "message": message,
            "has_files": files is not None and len(files) > 0,
        },
    ):
        exit_stack = None
        try:
            logger.info(
                f"Starting execution of agent {agent_id} for external_id {external_id}"
            )
            logger.info(f"Received message: {message}")

            if files and len(files) > 0:
                logger.info(f"Received {len(files)} files with message")

            get_root_agent = get_agent(db, agent_id)
            # Check before dereferencing the agent in the log call below
            if get_root_agent is None:
                raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
            logger.info(
                f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
            )

            # Using the AgentBuilder to create the agent
            agent_builder = AgentBuilder(db)
            result = await agent_builder.build_agent(get_root_agent)

            # Check how the result is structured
            if isinstance(result, tuple) and len(result) == 2:
                root_agent, exit_stack = result
            else:
                # If the result is not a tuple of 2 elements
                root_agent = result
                exit_stack = None
                logger.warning("build_agent did not return an exit_stack")

            # TODO: files should be processed here

            # Fetch session information
            crew_session_id = f"{external_id}_{agent_id}"
            if session_id is None:
                session_id = crew_session_id

            logger.info(f"Searching session for external_id {external_id}")
            try:
                session = session_service.get_session(
                    agent_id=agent_id,
                    external_id=external_id,
                    session_id=crew_session_id,
                )
            except Exception as e:
                logger.warning(f"Error getting session: {str(e)}")
                session = None

            if session is None:
                logger.info(f"Creating new session for external_id {external_id}")
                session = session_service.create_session(
                    agent_id=agent_id,
                    external_id=external_id,
                    session_id=crew_session_id,
                )

            # Add user message to session
            session.events.append(
                Event(
                    author="user",
                    content=Content(parts=[{"text": message}]),
                    timestamp=datetime.now().timestamp(),
                )
            )

            # Save session to database
            session_service.save_session(session)

            # Build message history for context
            conversation_history = []
            if session and session.events:
                for event in session.events:
                    if event.author and event.content and event.content.parts:
                        for part in event.content.parts:
                            if isinstance(part, dict) and "text" in part:
                                role = "User" if event.author == "user" else "Assistant"
                                conversation_history.append(f"{role}: {part['text']}")

            # Build description with history as context
            task_description = (
                f"Conversation history:\n" + "\n".join(conversation_history)
                if conversation_history
                else ""
            )
            task_description += f"\n\nCurrent user message: {message}"

            task = Task(
                name="resolve_user_request",
                description=task_description,
                expected_output="Response to the user request",
                agent=root_agent,
                verbose=True,
            )

            crew = await agent_builder.build_crew([root_agent], [task])

            # Use normal kickoff or kickoff_async instead of kickoff_for_each
            if hasattr(crew, "kickoff_async"):
                crew_output = await crew.kickoff_async(inputs={"message": message})
            else:
                loop = asyncio.get_event_loop()
                crew_output = await loop.run_in_executor(
                    None, lambda: crew.kickoff(inputs={"message": message})
                )

            # Extract response and add to session
            final_text = extract_text_from_output(crew_output)

            # Add agent response as event in session
            session.events.append(
                Event(
                    author=get_root_agent.name,
                    content=Content(parts=[{"text": final_text}]),
                    timestamp=datetime.now().timestamp(),
                )
            )

            # Save session with new event
            session_service.save_session(session)

            logger.info("Starting agent execution")

            final_response_text = "No final response captured."
            message_history = []

            try:
                response_queue = asyncio.Queue()
                execution_completed = asyncio.Event()

                async def process_events():
                    try:
                        # Log the result
                        logger.info(f"Crew output: {crew_output}")

                        # Signal that execution is complete
                        execution_completed.set()

                        # Extract text from CrewOutput object
                        final_text = "Unable to extract a valid response."

                        if hasattr(crew_output, "raw") and crew_output.raw:
                            final_text = crew_output.raw
                        elif hasattr(crew_output, "__str__"):
                            final_text = str(crew_output)

                        # If still empty or None, check crew artifacts
                        if not final_text or final_text.strip() == "":
                            # Try to get from agent messages
                            if hasattr(root_agent, "messages") and root_agent.messages:
                                # Get the last message from the agent
                                for msg in reversed(root_agent.messages):
                                    if hasattr(msg, "content") and msg.content:
                                        final_text = msg.content
                                        break

                        # If still empty, use a fallback
                        if not final_text or final_text.strip() == "":
                            final_text = "The agent could not produce a valid response. Please try again with a different question."

                        # Put the extracted text in the queue
                        await response_queue.put(final_text)
                    except Exception as e:
                        logger.error(f"Error in process_events: {str(e)}")
                        # Provide a more helpful error response
                        error_response = f"An error occurred during processing: {str(e)}\n\nIf you are trying to use external tools such as Brave Search, please make sure the connection is working properly."
                        await response_queue.put(error_response)
                        execution_completed.set()

                task = asyncio.create_task(process_events())

                try:
                    wait_task = asyncio.create_task(execution_completed.wait())
                    done, pending = await asyncio.wait({wait_task}, timeout=timeout)

                    for p in pending:
                        p.cancel()

                    if not execution_completed.is_set():
                        logger.warning(
                            f"Agent execution timed out after {timeout} seconds"
                        )
                        await response_queue.put(
                            "The response took too long and was interrupted."
                        )

                    final_response_text = await response_queue.get()

                except Exception as e:
                    logger.error(f"Error waiting for response: {str(e)}")
                    final_response_text = f"Error processing response: {str(e)}"

                # Add the session to memory after completion
                # completed_session = session_service.get_session(
                #     app_name=agent_id,
                #     user_id=external_id,
                #     session_id=crew_session_id,
                # )

                # memory_service.add_session_to_memory(completed_session)

                # Cancel the processing task if it is still running
                if not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        logger.info("Task cancelled successfully")
                    except Exception as e:
                        logger.error(f"Error cancelling task: {str(e)}")

            except Exception as e:
                logger.error(f"Error processing request: {str(e)}")
                raise InternalServerError(str(e)) from e

            logger.info("Agent execution completed successfully")
            return {
                "final_response": final_response_text,
                "message_history": message_history,
            }
        except AgentNotFoundError as e:
            logger.error(f"Error processing request: {str(e)}")
            raise e
        except Exception as e:
            logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
            raise InternalServerError(str(e))
        finally:
            # Clean up MCP connection - MUST be executed in the same task
            if exit_stack:
                logger.info("Closing MCP server connection...")
                try:
                    if hasattr(exit_stack, "aclose"):
                        # If it's an AsyncExitStack
                        await exit_stack.aclose()
                    elif isinstance(exit_stack, list):
                        # If it's a list of adapters
                        for adapter in exit_stack:
                            if hasattr(adapter, "close"):
                                adapter.close()
                except Exception as e:
                    logger.error(f"Error closing MCP connection: {e}")
                    # Do not raise the exception to not obscure the original error
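Driving this entry point looks like the chat route earlier in this commit (a
sketch; session_service is assumed to be the CrewSessionService instance
provided by service_providers):

    result = await run_agent(
        agent_id,
        external_id,
        "What changed in the last release?",
        session_service,
        db,
        timeout=60.0,
    )
    print(result["final_response"])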

def convert_sets(obj):
    # Recursively convert sets to lists so the structure is JSON-serializable.
    if isinstance(obj, set):
        return list(obj)
    elif isinstance(obj, dict):
        return {k: convert_sets(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_sets(i) for i in obj]
    else:
        return obj


async def run_agent_stream(
    agent_id: str,
    external_id: str,
    message: str,
    db: Session,
    session_id: Optional[str] = None,
    files: Optional[list] = None,
) -> AsyncGenerator[str, None]:
    tracer = get_tracer()
    span = tracer.start_span(
        "run_agent_stream",
        attributes={
            "agent_id": agent_id,
            "external_id": external_id,
            "session_id": session_id or f"{external_id}_{agent_id}",
            "message": message,
            "has_files": files is not None and len(files) > 0,
        },
    )
    exit_stack = None
    try:
        with trace.use_span(span, end_on_exit=True):
            try:
                logger.info(
                    f"Starting streaming execution of agent {agent_id} for external_id {external_id}"
                )
                logger.info(f"Received message: {message}")

                if files and len(files) > 0:
                    logger.info(f"Received {len(files)} files with message")

                get_root_agent = get_agent(db, agent_id)
                # Check before dereferencing the agent in the log call below
                if get_root_agent is None:
                    raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
                logger.info(
                    f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
                )

                # Using the AgentBuilder to create the agent
                agent_builder = AgentBuilder(db)
                result = await agent_builder.build_agent(get_root_agent)

                # Check how the result is structured
                if isinstance(result, tuple) and len(result) == 2:
                    root_agent, exit_stack = result
                else:
                    # If the result is not a tuple of 2 elements
                    root_agent = result
                    exit_stack = None
                    logger.warning("build_agent did not return an exit_stack")

                # TODO: files should be processed here

                # Fetch session history if available
                session_id = f"{external_id}_{agent_id}"

                # Create an instance of the session service
                try:
                    from src.config.settings import get_settings

                    settings = get_settings()
                    db_url = settings.DATABASE_URL
                except ImportError:
                    # Fallback to local SQLite if cannot import settings
                    db_url = "sqlite:///data/crew_sessions.db"

                session_service = CrewSessionService(db_url)

                try:
                    # Try to get existing session
                    session = session_service.get_session(
                        agent_id=agent_id,
                        external_id=external_id,
                        session_id=session_id,
                    )
                except Exception as e:
                    logger.warning(f"Could not load session: {e}")
                    session = None

                # Build message history for context
                conversation_history = []

                if session and session.events:
                    for event in session.events:
                        if event.author and event.content and event.content.parts:
                            for part in event.content.parts:
                                if isinstance(part, dict) and "text" in part:
                                    role = (
                                        "User"
                                        if event.author == "user"
                                        else "Assistant"
                                    )
                                    conversation_history.append(
                                        f"{role}: {part['text']}"
                                    )

                # Build description with history
                task_description = (
                    f"Conversation history:\n" + "\n".join(conversation_history)
                    if conversation_history
                    else ""
                )
                task_description += f"\n\nCurrent user message: {message}"

                task = Task(
                    name="resolve_user_request",
                    description=task_description,
                    expected_output="Response to the user request",
                    agent=root_agent,
                    verbose=True,
                )

                crew = await agent_builder.build_crew([root_agent], [task])

                logger.info("Starting agent streaming execution")

                try:
                    # Check if we can process messages with kickoff_for_each
                    if hasattr(crew, "kickoff_for_each"):
                        # Create input with current message
                        inputs = [{"message": message}]
                        logger.info(
                            f"Using kickoff_for_each for streaming with {len(inputs)} input(s)"
                        )

                        # Execute kickoff_for_each
                        results = crew.kickoff_for_each(inputs=inputs)

                        # Print results and save to session
                        for i, result in enumerate(results):
                            logger.info(f"Result of event {i+1}: {result}")

                            # If we have a session, save the response to it
                            if session:
                                # Add agent response as event
                                session.events.append(
                                    Event(
                                        author="agent",
                                        content=Content(parts=[{"text": result}]),
                                        timestamp=datetime.now().timestamp(),
                                    )
                                )

                        # Save current session with new message
                        if session:
                            # Also add user message if it doesn't exist yet
                            if not any(
                                e.author == "user"
                                and any(
                                    p.get("text") == message for p in e.content.parts
                                )
                                for e in session.events
                                if e.content and e.content.parts
                            ):
                                session.events.append(
                                    Event(
                                        author="user",
                                        content=Content(parts=[{"text": message}]),
                                        timestamp=datetime.now().timestamp(),
                                    )
                                )
                            # Save session
                            try:
                                session_service.save_session(session)
                                logger.info(f"Session saved successfully: {session_id}")
                            except Exception as e:
                                logger.error(f"Error saving session: {e}")

                        # Use last result as final output
                        crew_output = results[-1] if results else None
                    else:
                        # CrewAI kickoff method is synchronous, fallback if kickoff_for_each not available
                        logger.info(
                            "kickoff_for_each not available, using standard kickoff for streaming"
                        )
                        crew_output = crew.kickoff()

                    logger.info(f"Crew output: {crew_output}")

                    # Extract the actual text content
                    if hasattr(crew_output, "raw") and crew_output.raw:
                        final_output = crew_output.raw
                    elif hasattr(crew_output, "__str__"):
                        final_output = str(crew_output)
                    else:
                        final_output = "Could not extract text from response"

                    # Save response to session (for fallback case of normal kickoff)
                    if session and not hasattr(crew, "kickoff_for_each"):
                        # Add agent response
                        session.events.append(
                            Event(
                                author="agent",
                                content=Content(parts=[{"text": final_output}]),
                                timestamp=datetime.now().timestamp(),
                            )
                        )

                        # Add user message if it doesn't exist yet
                        if not any(
                            e.author == "user"
                            and any(p.get("text") == message for p in e.content.parts)
                            for e in session.events
                            if e.content and e.content.parts
                        ):
                            session.events.append(
                                Event(
                                    author="user",
                                    content=Content(parts=[{"text": message}]),
                                    timestamp=datetime.now().timestamp(),
                                )
                            )

                        # Save session
                        try:
                            session_service.save_session(session)
                            logger.info(
                                f"Session saved successfully (method: kickoff): {session_id}"
                            )
                        except Exception as e:
                            logger.error(f"Error saving session: {e}")

                    yield json.dumps({"text": final_output})
                except Exception as e:
                    logger.error(f"Error processing request: {str(e)}")
                    raise InternalServerError(str(e)) from e
                finally:
                    # Clean up MCP connection
                    if exit_stack:
                        logger.info("Closing MCP server connection...")
                        try:
                            if hasattr(exit_stack, "aclose"):
                                # If it's an AsyncExitStack
                                await exit_stack.aclose()
                            elif isinstance(exit_stack, list):
                                # If it's a list of adapters
                                for adapter in exit_stack:
                                    if hasattr(adapter, "close"):
                                        adapter.close()
                        except Exception as e:
                            logger.error(f"Error closing MCP connection: {e}")
                            # Do not raise the exception to not obscure the original error

                logger.info("Agent streaming execution completed successfully")
            except AgentNotFoundError as e:
                logger.error(f"Error processing request: {str(e)}")
                raise InternalServerError(str(e)) from e
            except Exception as e:
                logger.error(
                    f"Internal error processing request: {str(e)}", exc_info=True
                )
                raise InternalServerError(str(e))
    finally:
        span.end()
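Consuming the stream (a sketch; note the generator currently yields a single
JSON chunk per run rather than token-level deltas):

    async for chunk in run_agent_stream(agent_id, external_id, message, db):
        data = json.loads(chunk)
        print(data["text"])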
src/services/crewai/custom_tool.py (new file, 369 lines)

"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes                                                      │
│ @file: custom_tool.py                                                        │
│ Developed by: Davidson Gomes                                                 │
│ Creation date: May 13, 2025                                                  │
│ Contact: contato@evolution-api.com                                           │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved.                        │
│ Licensed under the Apache License, Version 2.0                               │
│                                                                              │
│ You may not use this file except in compliance with the License.             │
│ You may obtain a copy of the License at                                      │
│                                                                              │
│    http://www.apache.org/licenses/LICENSE-2.0                                │
│                                                                              │
│ Unless required by applicable law or agreed to in writing, software          │
│ distributed under the License is distributed on an "AS IS" BASIS,            │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     │
│ See the License for the specific language governing permissions and          │
│ limitations under the License.                                               │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important                                                                   │
│ For any future changes to the code in this file, it is recommended to        │
│ include, together with the modification, the information of the developer    │
│ who changed it and the date of modification.                                 │
└──────────────────────────────────────────────────────────────────────────────┘
"""

from typing import Any, Dict, List, Type
from crewai.tools import BaseTool, tool
import requests
import json
from src.utils.logger import setup_logger
from pydantic import BaseModel, Field, create_model

logger = setup_logger(__name__)


class CustomToolBuilder:
    def __init__(self):
        self.tools = []

    def _create_http_tool(self, tool_config: Dict[str, Any]) -> BaseTool:
        """Create an HTTP tool based on the provided configuration."""
        # Extract configuration parameters
        name = tool_config["name"]
        description = tool_config["description"]
        endpoint = tool_config["endpoint"]
        method = tool_config["method"]
        headers = tool_config.get("headers", {})
        parameters = tool_config.get("parameters", {}) or {}
        values = tool_config.get("values", {})
        error_handling = tool_config.get("error_handling", {})

        path_params = parameters.get("path_params") or {}
        query_params = parameters.get("query_params") or {}
        body_params = parameters.get("body_params") or {}

        # Dynamic creation of the input schema for the tool
        field_definitions = {}

        # Add all parameters as fields
        for param in (
            list(path_params.keys())
            + list(query_params.keys())
            + list(body_params.keys())
        ):
            # Default to string type for all parameters
            field_definitions[param] = (
                str,
                Field(..., description=f"Parameter {param}"),
            )

        # If there are no parameters but default values, use those as optional fields
        if not field_definitions and values:
            for param, value in values.items():
                param_type = type(value)
                field_definitions[param] = (
                    param_type,
                    Field(default=value, description=f"Parameter {param}"),
                )

        # Create dynamic input schema model in line with the documentation
        tool_input_model = create_model(
            f"{name.replace(' ', '')}Input", **field_definitions
        )

        # Create the HTTP tool using crewai's BaseTool class
        # Following the pattern in the documentation
        def create_http_tool_class():
            # Capture variables from outer scope
            _name = name
            _description = description
            _tool_input_model = tool_input_model

            class HttpTool(BaseTool):
                name: str = _name
                description: str = _description
                args_schema: Type[BaseModel] = _tool_input_model

                def _run(self, **kwargs):
                    """Execute the HTTP request and return the result."""
                    try:
                        # Combines default values with provided values
                        all_values = {**values, **kwargs}

                        # Substitutes placeholders in headers
                        processed_headers = {
                            k: v.format(**all_values) if isinstance(v, str) else v
                            for k, v in headers.items()
                        }

                        # Processes path parameters
                        url = endpoint
                        for param, value in path_params.items():
                            if param in all_values:
                                url = url.replace(
                                    f"{{{param}}}", str(all_values[param])
                                )

                        # Process query parameters
                        query_params_dict = {}
                        for param, value in query_params.items():
                            if isinstance(value, list):
                                # If the value is a list, join with comma
                                query_params_dict[param] = ",".join(value)
                            elif param in all_values:
                                # If the parameter is in the values, use the provided value
                                query_params_dict[param] = all_values[param]
                            else:
                                # Otherwise, use the default value from the configuration
                                query_params_dict[param] = value

                        # Adds default values to query params if they are not present
                        for param, value in values.items():
                            if (
                                param not in query_params_dict
                                and param not in path_params
                            ):
                                query_params_dict[param] = value

                        body_data = {}
                        for param, param_config in body_params.items():
                            if param in all_values:
                                body_data[param] = all_values[param]

                        # Adds default values to body if they are not present
                        for param, value in values.items():
                            if (
                                param not in body_data
                                and param not in query_params_dict
                                and param not in path_params
                            ):
                                body_data[param] = value

                        # Makes the HTTP request
                        response = requests.request(
                            method=method,
                            url=url,
                            headers=processed_headers,
                            params=query_params_dict,
                            json=body_data or None,
                            timeout=error_handling.get("timeout", 30),
                        )

                        if response.status_code >= 400:
                            raise requests.exceptions.HTTPError(
                                f"Error in the request: {response.status_code} - {response.text}"
                            )

                        # Always returns the response as a string
                        return json.dumps(response.json())

                    except Exception as e:
                        logger.error(f"Error executing tool {name}: {str(e)}")
                        return json.dumps(
                            error_handling.get(
                                "fallback_response",
                                {"error": "tool_execution_error", "message": str(e)},
                            )
                        )

            return HttpTool

        # Create the tool instance
        HttpToolClass = create_http_tool_class()
        http_tool = HttpToolClass()

        # Add cache function following the documentation
        def http_cache_function(arguments: dict, result: str) -> bool:
            """Determines whether to cache the result based on arguments and result."""
            # Default implementation: cache all successful results
            try:
                # If the result is parseable JSON and not an error, cache it
                result_obj = json.loads(result)
                return not (isinstance(result_obj, dict) and "error" in result_obj)
            except Exception:
                # If result is not valid JSON, don't cache
                return False

        # Assign the cache function to the tool
        http_tool.cache_function = http_cache_function

        return http_tool
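An example of the configuration shape this builder consumes (the field names
match the code above; the concrete values are illustrative only):

    tool_config = {
        "name": "Get Ticket",
        "description": "Fetch a support ticket by id",
        "endpoint": "https://api.example.com/tickets/{ticket_id}",
        "method": "GET",
        "headers": {"Authorization": "Bearer {token}"},
        "parameters": {"path_params": {"ticket_id": "Ticket id"}},
        "values": {"token": "secret-token"},
        "error_handling": {"timeout": 30, "fallback_response": {"error": "unavailable"}},
    }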
|
    def _create_http_tool_with_decorator(self, tool_config: Dict[str, Any]) -> Any:
        """Create an HTTP tool using the tool decorator."""
        # Extract configuration parameters
        name = tool_config["name"]
        description = tool_config["description"]
        endpoint = tool_config["endpoint"]
        method = tool_config["method"]
        headers = tool_config.get("headers", {})
        parameters = tool_config.get("parameters", {}) or {}
        values = tool_config.get("values", {})
        error_handling = tool_config.get("error_handling", {})

        path_params = parameters.get("path_params") or {}
        query_params = parameters.get("query_params") or {}
        body_params = parameters.get("body_params") or {}

        # Create function docstring with parameter documentation
        param_list = (
            list(path_params.keys())
            + list(query_params.keys())
            + list(body_params.keys())
        )
        doc_params = []
        for param in param_list:
            doc_params.append(f"    {param}: Parameter description")

        docstring = (
            f"{description}\n\nParameters:\n"
            + "\n".join(doc_params)
            + "\n\nReturns:\n    String containing the response in JSON format"
        )

        # Create the tool function using the decorator pattern in the documentation
        @tool(name=name)
        def http_tool(**kwargs):
            """Tool function created dynamically."""
            try:
                # Combines default values with provided values
                all_values = {**values, **kwargs}

                # Substitutes placeholders in headers
                processed_headers = {
                    k: v.format(**all_values) if isinstance(v, str) else v
                    for k, v in headers.items()
                }

                # Processes path parameters
                url = endpoint
                for param, value in path_params.items():
                    if param in all_values:
                        url = url.replace(f"{{{param}}}", str(all_values[param]))

                # Process query parameters
                query_params_dict = {}
                for param, value in query_params.items():
                    if isinstance(value, list):
                        # If the value is a list, join with comma
                        query_params_dict[param] = ",".join(value)
                    elif param in all_values:
                        # If the parameter is in the values, use the provided value
                        query_params_dict[param] = all_values[param]
                    else:
                        # Otherwise, use the default value from the configuration
                        query_params_dict[param] = value

                # Adds default values to query params if they are not present
                for param, value in values.items():
                    if param not in query_params_dict and param not in path_params:
                        query_params_dict[param] = value

                body_data = {}
                for param, param_config in body_params.items():
                    if param in all_values:
                        body_data[param] = all_values[param]

                # Adds default values to body if they are not present
                for param, value in values.items():
                    if (
                        param not in body_data
                        and param not in query_params_dict
                        and param not in path_params
                    ):
                        body_data[param] = value

                # Makes the HTTP request
                response = requests.request(
                    method=method,
                    url=url,
                    headers=processed_headers,
                    params=query_params_dict,
                    json=body_data or None,
                    timeout=error_handling.get("timeout", 30),
                )

                if response.status_code >= 400:
                    raise requests.exceptions.HTTPError(
                        f"Error in the request: {response.status_code} - {response.text}"
                    )

                # Always returns the response as a string
                return json.dumps(response.json())

            except Exception as e:
                logger.error(f"Error executing tool {name}: {str(e)}")
                return json.dumps(
                    error_handling.get(
                        "fallback_response",
                        {"error": "tool_execution_error", "message": str(e)},
                    )
                )

        # Replace the docstring
        http_tool.__doc__ = docstring

        # Add cache function following the documentation
        def http_cache_function(arguments: dict, result: str) -> bool:
            """Determines whether to cache the result based on arguments and result."""
            # Default implementation: cache all successful results
            try:
                # If the result is parseable JSON and not an error, cache it
                result_obj = json.loads(result)
                return not (isinstance(result_obj, dict) and "error" in result_obj)
            except Exception:
                # If result is not valid JSON, don't cache
                return False

        # Assign the cache function to the tool
        http_tool.cache_function = http_cache_function

        return http_tool

    def build_tools(self, tools_config: Dict[str, Any]) -> List[BaseTool]:
        """Builds a list of tools based on the provided configuration.
        Accepts both 'tools' and 'custom_tools' (with http_tools)."""
        self.tools = []

        # Find HTTP tools configuration in various possible locations
        http_tools = []
        if tools_config.get("http_tools"):
            http_tools = tools_config.get("http_tools", [])
        elif tools_config.get("custom_tools") and tools_config["custom_tools"].get(
            "http_tools"
        ):
            http_tools = tools_config["custom_tools"].get("http_tools", [])
        elif (
            tools_config.get("tools")
            and isinstance(tools_config["tools"], dict)
            and tools_config["tools"].get("http_tools")
        ):
            http_tools = tools_config["tools"].get("http_tools", [])

        # Determine which implementation method to use (BaseTool or decorator)
        use_decorator = tools_config.get("use_decorator", False)

        # Create tools for each HTTP tool configuration
        for http_tool_config in http_tools:
            if use_decorator:
                self.tools.append(
                    self._create_http_tool_with_decorator(http_tool_config)
                )
            else:
                self.tools.append(self._create_http_tool(http_tool_config))

        return self.tools
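To make the expected configuration shape concrete, here is a minimal sketch of a tools_config accepted by build_tools. All values are assumptions for illustration (placeholder endpoint, token, and builder instance name); only the key names come from the code above.

# Hedged sketch: one GET tool against a placeholder API.
tools_config = {
    "http_tools": [
        {
            "name": "get_user",
            "description": "Fetch a user by id",
            "endpoint": "https://api.example.com/users/{user_id}",  # assumed URL
            "method": "GET",
            "headers": {"Authorization": "Bearer {api_key}"},  # placeholders filled from values/kwargs
            "parameters": {
                "path_params": {"user_id": None},
                "query_params": {},
                "body_params": {},
            },
            "values": {"api_key": "..."},        # default placeholder values
            "error_handling": {"timeout": 15},   # request timeout; defaults to 30 if omitted
        }
    ],
    "use_decorator": False,  # True switches to _create_http_tool_with_decorator
}
tools = builder.build_tools(tools_config)  # `builder` is an instance of this builder class (assumed name)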
264 src/services/crewai/mcp_service.py Normal file
@ -0,0 +1,264 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes                                                      │
│ @file: mcp_service.py                                                        │
│ Developed by: Davidson Gomes                                                 │
│ Creation date: May 13, 2025                                                  │
│ Contact: contato@evolution-api.com                                           │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved.                        │
│ Licensed under the Apache License, Version 2.0                               │
│                                                                              │
│ You may not use this file except in compliance with the License.             │
│ You may obtain a copy of the License at                                      │
│                                                                              │
│    http://www.apache.org/licenses/LICENSE-2.0                                │
│                                                                              │
│ Unless required by applicable law or agreed to in writing, software          │
│ distributed under the License is distributed on an "AS IS" BASIS,            │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     │
│ See the License for the specific language governing permissions and          │
│ limitations under the License.                                               │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important                                                                   │
│ For any future changes to the code in this file, it is recommended to        │
│ include, together with the modification, the information of the developer    │
│ who changed it and the date of modification.                                 │
└──────────────────────────────────────────────────────────────────────────────┘
"""

from typing import Any, Dict, List, Optional, Tuple
from contextlib import ExitStack
import os
import sys

from src.utils.logger import setup_logger
from src.services.mcp_server_service import get_mcp_server
from sqlalchemy.orm import Session

try:
    from crewai_tools import MCPServerAdapter
    from mcp import StdioServerParameters

    HAS_MCP_PACKAGES = True
except ImportError:
    logger = setup_logger(__name__)
    logger.error(
        "MCP packages are not installed. Please install mcp and crewai-tools[mcp]"
    )
    HAS_MCP_PACKAGES = False

logger = setup_logger(__name__)

class MCPService:
    def __init__(self):
        self.tools = []
        self.exit_stack = ExitStack()

    def _connect_to_mcp_server(
        self, server_config: Dict[str, Any]
    ) -> Tuple[List[Any], Optional[ExitStack]]:
        """Connect to a specific MCP server and return its tools."""
        if not HAS_MCP_PACKAGES:
            logger.error("Cannot connect to MCP server: MCP packages not installed")
            return [], None

        try:
            # Determines the type of server (local or remote)
            if "url" in server_config:
                # Remote server (SSE) - Simplified approach using direct dictionary
                sse_config = {"url": server_config["url"]}

                # Add headers if provided
                if "headers" in server_config and server_config["headers"]:
                    sse_config["headers"] = server_config["headers"]

                # Create the MCPServerAdapter with the SSE configuration
                mcp_adapter = MCPServerAdapter(sse_config)
            else:
                # Local server (Stdio)
                command = server_config.get("command", "npx")
                args = server_config.get("args", [])

                # Adds environment variables if specified
                env = server_config.get("env", {})
                if env:
                    for key, value in env.items():
                        os.environ[key] = value

                connection_params = StdioServerParameters(
                    command=command, args=args, env=env
                )

                # Create the MCPServerAdapter with the Stdio connection parameters
                mcp_adapter = MCPServerAdapter(connection_params)

            # Get tools from the adapter
            tools = mcp_adapter.tools

            # Return tools and the adapter (which serves as an exit stack)
            return tools, mcp_adapter

        except Exception as e:
            logger.error(f"Error connecting to MCP server: {e}")
            return [], None
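For orientation, the two server_config shapes that _connect_to_mcp_server distinguishes look roughly like this. The URLs, package names, and env values below are illustrative assumptions, not part of the diff.

# Remote (SSE) server: selected because the config carries a "url" key.
sse_server = {
    "url": "https://example.com/mcp/sse",            # assumed URL
    "headers": {"Authorization": "Bearer <token>"},  # optional
}

# Local (stdio) server: launched as a subprocess via command/args.
stdio_server = {
    "command": "npx",
    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],  # assumed args
    "env": {"EXAMPLE_API_KEY": "..."},  # exported into os.environ before launch
}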
    def _filter_incompatible_tools(self, tools: List[Any]) -> List[Any]:
        """Filters out tools that are incompatible with the model."""
        problematic_tools = [
            "create_pull_request_review",  # This tool causes the 400 INVALID_ARGUMENT error
        ]

        filtered_tools = []
        removed_count = 0

        for tool in tools:
            if tool.name in problematic_tools:
                logger.warning(f"Removing incompatible tool: {tool.name}")
                removed_count += 1
            else:
                filtered_tools.append(tool)

        if removed_count > 0:
            logger.warning(f"Removed {removed_count} incompatible tools.")

        return filtered_tools

    def _filter_tools_by_agent(
        self, tools: List[Any], agent_tools: List[str]
    ) -> List[Any]:
        """Keeps only the tools that are enabled for the agent."""
        if not agent_tools:
            return tools

        filtered_tools = []
        for tool in tools:
            logger.info(f"Tool: {tool.name}")
            if tool.name in agent_tools:
                filtered_tools.append(tool)
        return filtered_tools
    async def build_tools(
        self, mcp_config: Dict[str, Any], db: Session
    ) -> Tuple[List[Any], Any]:
        """Builds a list of tools from multiple MCP servers."""
        if not HAS_MCP_PACKAGES:
            logger.error("Cannot build MCP tools: MCP packages not installed")
            return [], None

        self.tools = []
        self.exit_stack = ExitStack()
        adapter_list = []

        try:
            mcp_servers = mcp_config.get("mcp_servers", [])
            if mcp_servers is not None:
                # Process each MCP server in the configuration
                for server in mcp_servers:
                    try:
                        # Search for the MCP server in the database
                        mcp_server = get_mcp_server(db, server["id"])
                        if not mcp_server:
                            logger.warning(f"MCP Server not found: {server['id']}")
                            continue

                        # Prepares the server configuration
                        server_config = mcp_server.config_json.copy()

                        # Replaces the environment variables in the config_json
                        if "env" in server_config and server_config["env"] is not None:
                            for key, value in server_config["env"].items():
                                if value and value.startswith("env@@"):
                                    env_key = value.replace("env@@", "")
                                    if server.get("envs") and env_key in server.get(
                                        "envs", {}
                                    ):
                                        server_config["env"][key] = server["envs"][
                                            env_key
                                        ]
                                    else:
                                        logger.warning(
                                            f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
                                        )
                                        continue

                        logger.info(f"Connecting to MCP server: {mcp_server.name}")
                        tools, adapter = self._connect_to_mcp_server(server_config)

                        if tools and adapter:
                            # Filters incompatible tools
                            filtered_tools = self._filter_incompatible_tools(tools)

                            # Filters tools compatible with the agent
                            if agent_tools := server.get("tools", []):
                                filtered_tools = self._filter_tools_by_agent(
                                    filtered_tools, agent_tools
                                )
                            self.tools.extend(filtered_tools)

                            # Add to the adapter list for cleanup later
                            adapter_list.append(adapter)
                            logger.info(
                                f"MCP Server {mcp_server.name} connected successfully. Added {len(filtered_tools)} tools."
                            )
                        else:
                            logger.warning(
                                f"Failed to connect or no tools available for {mcp_server.name}"
                            )

                    except Exception as e:
                        logger.error(
                            f"Error connecting to MCP server {server.get('id', 'unknown')}: {e}"
                        )
                        continue

            custom_mcp_servers = mcp_config.get("custom_mcp_servers", [])
            if custom_mcp_servers is not None:
                # Process custom MCP servers
                for server in custom_mcp_servers:
                    if not server:
                        logger.warning(
                            "Empty server configuration found in custom_mcp_servers"
                        )
                        continue

                    try:
                        logger.info(
                            f"Connecting to custom MCP server: {server.get('url', 'unknown')}"
                        )
                        tools, adapter = self._connect_to_mcp_server(server)

                        if tools:
                            self.tools.extend(tools)
                        else:
                            logger.warning("No tools returned from custom MCP server")
                            continue

                        if adapter:
                            adapter_list.append(adapter)
                            logger.info(
                                f"Custom MCP server connected successfully. Added {len(tools)} tools."
                            )
                        else:
                            logger.warning("No adapter returned from custom MCP server")
                    except Exception as e:
                        logger.error(
                            f"Error connecting to custom MCP server {server.get('url', 'unknown')}: {e}"
                        )
                        continue

            logger.info(
                f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
            )

        except Exception as e:
            # Ensure cleanup
            for adapter in adapter_list:
                if hasattr(adapter, "close"):
                    adapter.close()
            logger.error(f"Fatal error connecting to MCP servers: {e}")
            # Return empty lists in case of error
            return [], None

        # Return the tools and the adapter list for cleanup
        return self.tools, adapter_list
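A rough caller sketch, showing how the returned adapter list is meant to be cleaned up. The db session, config values, and function name are assumptions, not part of this diff:

# Hedged sketch of a caller; `db` is an open SQLAlchemy session (assumed).
async def run_with_mcp_tools(db):
    service = MCPService()
    config = {"custom_mcp_servers": [{"url": "https://example.com/mcp/sse"}]}
    tools, adapters = await service.build_tools(config, db)
    try:
        ...  # hand `tools` to the CrewAI agent and run it
    finally:
        # The adapters double as exit stacks and should be closed when done.
        for adapter in adapters or []:
            if hasattr(adapter, "close"):
                adapter.close()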
637 src/services/crewai/session_service.py Normal file
@ -0,0 +1,637 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes                                                      │
│ @file: session_service.py                                                    │
│ Developed by: Davidson Gomes                                                 │
│ Creation date: May 13, 2025                                                  │
│ Contact: contato@evolution-api.com                                           │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved.                        │
│ Licensed under the Apache License, Version 2.0                               │
│                                                                              │
│ You may not use this file except in compliance with the License.             │
│ You may obtain a copy of the License at                                      │
│                                                                              │
│    http://www.apache.org/licenses/LICENSE-2.0                                │
│                                                                              │
│ Unless required by applicable law or agreed to in writing, software          │
│ distributed under the License is distributed on an "AS IS" BASIS,            │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     │
│ See the License for the specific language governing permissions and          │
│ limitations under the License.                                               │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important                                                                   │
│ For any future changes to the code in this file, it is recommended to        │
│ include, together with the modification, the information of the developer    │
│ who changed it and the date of modification.                                 │
└──────────────────────────────────────────────────────────────────────────────┘
"""

from datetime import datetime
import json
import uuid
import base64
import copy
from typing import Any, Dict, List, Optional, Union, Set

from sqlalchemy import create_engine, Boolean, Text, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import (
    sessionmaker,
    relationship,
    DeclarativeBase,
    Mapped,
    mapped_column,
)
from sqlalchemy.sql import func
from sqlalchemy.types import DateTime, PickleType, String
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import TypeDecorator

from pydantic import BaseModel

from src.utils.logger import setup_logger

logger = setup_logger(__name__)

class DynamicJSON(TypeDecorator):
    """JSON type compatible with ADK that uses JSONB in PostgreSQL and TEXT with JSON
    serialization for other databases."""

    impl = Text  # Default implementation is TEXT

    def load_dialect_impl(self, dialect):
        if dialect.name == "postgresql":
            return dialect.type_descriptor(postgresql.JSONB)
        else:
            return dialect.type_descriptor(Text)

    def process_bind_param(self, value, dialect):
        if value is not None:
            if dialect.name == "postgresql":
                return value
            else:
                return json.dumps(value)
        return value

    def process_result_value(self, value, dialect):
        if value is not None:
            if dialect.name == "postgresql":
                return value
            else:
                return json.loads(value)
        return value
class Base(DeclarativeBase):
    """Base class for database tables."""

    pass


class StorageSession(Base):
    """Represents a session stored in the database, compatible with ADK."""

    __tablename__ = "sessions"

    app_name: Mapped[str] = mapped_column(String, primary_key=True)
    user_id: Mapped[str] = mapped_column(String, primary_key=True)
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )

    state: Mapped[MutableDict[str, Any]] = mapped_column(
        MutableDict.as_mutable(DynamicJSON), default={}
    )

    create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
    update_time: Mapped[DateTime] = mapped_column(
        DateTime(), default=func.now(), onupdate=func.now()
    )

    storage_events: Mapped[list["StorageEvent"]] = relationship(
        "StorageEvent",
        back_populates="storage_session",
    )

    def __repr__(self):
        return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
class StorageEvent(Base):
    """Represents an event stored in the database, compatible with ADK."""

    __tablename__ = "events"

    id: Mapped[str] = mapped_column(String, primary_key=True)
    app_name: Mapped[str] = mapped_column(String, primary_key=True)
    user_id: Mapped[str] = mapped_column(String, primary_key=True)
    session_id: Mapped[str] = mapped_column(String, primary_key=True)

    invocation_id: Mapped[str] = mapped_column(String)
    author: Mapped[str] = mapped_column(String)
    branch: Mapped[str] = mapped_column(String, nullable=True)
    timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
    content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
    actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)

    long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
        Text, nullable=True
    )
    grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
        DynamicJSON, nullable=True
    )
    partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
    turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
    error_code: Mapped[str] = mapped_column(String, nullable=True)
    error_message: Mapped[str] = mapped_column(String, nullable=True)
    interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)

    storage_session: Mapped[StorageSession] = relationship(
        "StorageSession",
        back_populates="storage_events",
    )

    __table_args__ = (
        ForeignKeyConstraint(
            ["app_name", "user_id", "session_id"],
            ["sessions.app_name", "sessions.user_id", "sessions.id"],
            ondelete="CASCADE",
        ),
    )

    @property
    def long_running_tool_ids(self) -> set[str]:
        return (
            set(json.loads(self.long_running_tool_ids_json))
            if self.long_running_tool_ids_json
            else set()
        )

    @long_running_tool_ids.setter
    def long_running_tool_ids(self, value: set[str]):
        if value is None:
            self.long_running_tool_ids_json = None
        else:
            self.long_running_tool_ids_json = json.dumps(list(value))
class StorageAppState(Base):
    """Represents an application state stored in the database, compatible with ADK."""

    __tablename__ = "app_states"

    app_name: Mapped[str] = mapped_column(String, primary_key=True)
    state: Mapped[MutableDict[str, Any]] = mapped_column(
        MutableDict.as_mutable(DynamicJSON), default={}
    )
    update_time: Mapped[DateTime] = mapped_column(
        DateTime(), default=func.now(), onupdate=func.now()
    )


class StorageUserState(Base):
    """Represents a user state stored in the database, compatible with ADK."""

    __tablename__ = "user_states"

    app_name: Mapped[str] = mapped_column(String, primary_key=True)
    user_id: Mapped[str] = mapped_column(String, primary_key=True)
    state: Mapped[MutableDict[str, Any]] = mapped_column(
        MutableDict.as_mutable(DynamicJSON), default={}
    )
    update_time: Mapped[DateTime] = mapped_column(
        DateTime(), default=func.now(), onupdate=func.now()
    )

# Pydantic model classes compatible with ADK
class State:
    """Utility class for states, compatible with ADK."""

    APP_PREFIX = "app:"
    USER_PREFIX = "user:"
    TEMP_PREFIX = "temp:"


class Content(BaseModel):
    """Event content model, compatible with ADK."""

    parts: List[Dict[str, Any]]


class Part(BaseModel):
    """Content part model, compatible with ADK."""

    text: Optional[str] = None


class Event(BaseModel):
    """Event model, compatible with ADK."""

    id: Optional[str] = None
    author: str
    branch: Optional[str] = None
    invocation_id: Optional[str] = None
    content: Optional[Content] = None
    actions: Optional[Dict[str, Any]] = None
    timestamp: Optional[float] = None
    long_running_tool_ids: Optional[Set[str]] = None
    grounding_metadata: Optional[Dict[str, Any]] = None
    partial: Optional[bool] = None
    turn_complete: Optional[bool] = None
    error_code: Optional[str] = None
    error_message: Optional[str] = None
    interrupted: Optional[bool] = None


class Session(BaseModel):
    """Session model, compatible with ADK."""

    app_name: str
    user_id: str
    id: str
    state: Dict[str, Any] = {}
    events: List[Event] = []
    last_update_time: float

    class Config:
        arbitrary_types_allowed = True

class CrewSessionService:
    """Service for managing CrewAI agent sessions using ADK tables."""

    def __init__(self, db_url: str):
        """
        Initializes the session service.

        Args:
            db_url: Database connection URL.
        """
        try:
            self.engine = create_engine(db_url)
        except Exception as e:
            raise ValueError(f"Failed to create database engine: {e}")

        # Create all tables
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)
        logger.info(f"CrewSessionService started with database at {db_url}")

    def create_session(
        self, agent_id: str, external_id: str, session_id: Optional[str] = None
    ) -> Session:
        """
        Creates a new session for an agent.

        Args:
            agent_id: Agent ID (used as app_name in ADK)
            external_id: External ID (used as user_id in ADK)
            session_id: Optional session ID

        Returns:
            Session: The created session
        """
        session_id = session_id or str(uuid.uuid4())

        with self.Session() as db_session:
            # Check if app and user states already exist
            storage_app_state = db_session.get(StorageAppState, (agent_id))
            storage_user_state = db_session.get(
                StorageUserState, (agent_id, external_id)
            )

            app_state = storage_app_state.state if storage_app_state else {}
            user_state = storage_user_state.state if storage_user_state else {}

            # Create states if they don't exist
            if not storage_app_state:
                storage_app_state = StorageAppState(app_name=agent_id, state={})
                db_session.add(storage_app_state)

            if not storage_user_state:
                storage_user_state = StorageUserState(
                    app_name=agent_id, user_id=external_id, state={}
                )
                db_session.add(storage_user_state)

            # Create session
            storage_session = StorageSession(
                app_name=agent_id,
                user_id=external_id,
                id=session_id,
                state={},
            )
            db_session.add(storage_session)
            db_session.commit()

            # Get timestamp
            db_session.refresh(storage_session)

            # Merge states for response
            merged_state = _merge_state(app_state, user_state, {})

            # Create Session object for return
            session = Session(
                app_name=agent_id,
                user_id=external_id,
                id=session_id,
                state=merged_state,
                last_update_time=storage_session.update_time.timestamp(),
            )

        logger.info(
            f"Session created: {session_id} for agent {agent_id} and user {external_id}"
        )
        return session
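A minimal usage sketch; the connection string and IDs below are assumed values, not part of the diff:

# Hedged example: wiring CrewSessionService by hand.
service = CrewSessionService(db_url="postgresql://user:pass@localhost:5432/evo_ai")

session = service.create_session(agent_id="agent-123", external_id="customer-42")
print(session.id, session.state)  # fresh UUID, merged app:/user: state (empty here)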
    def get_session(
        self, agent_id: str, external_id: str, session_id: str
    ) -> Optional[Session]:
        """
        Retrieves a session from the database.

        Args:
            agent_id: Agent ID
            external_id: User ID
            session_id: Session ID

        Returns:
            Optional[Session]: The retrieved session or None if not found
        """
        with self.Session() as db_session:
            storage_session = db_session.get(
                StorageSession, (agent_id, external_id, session_id)
            )

            if storage_session is None:
                return None

            # Fetch session events
            storage_events = (
                db_session.query(StorageEvent)
                .filter(StorageEvent.session_id == storage_session.id)
                .filter(StorageEvent.app_name == agent_id)
                .filter(StorageEvent.user_id == external_id)
                .all()
            )

            # Fetch states
            storage_app_state = db_session.get(StorageAppState, (agent_id))
            storage_user_state = db_session.get(
                StorageUserState, (agent_id, external_id)
            )

            app_state = storage_app_state.state if storage_app_state else {}
            user_state = storage_user_state.state if storage_user_state else {}
            session_state = storage_session.state

            # Merge states
            merged_state = _merge_state(app_state, user_state, session_state)

            # Create session
            session = Session(
                app_name=agent_id,
                user_id=external_id,
                id=session_id,
                state=merged_state,
                last_update_time=storage_session.update_time.timestamp(),
            )

            # Add events
            session.events = [
                Event(
                    id=e.id,
                    author=e.author,
                    branch=e.branch,
                    invocation_id=e.invocation_id,
                    content=_decode_content(e.content),
                    actions=e.actions,
                    timestamp=e.timestamp.timestamp(),
                    long_running_tool_ids=e.long_running_tool_ids,
                    grounding_metadata=e.grounding_metadata,
                    partial=e.partial,
                    turn_complete=e.turn_complete,
                    error_code=e.error_code,
                    error_message=e.error_message,
                    interrupted=e.interrupted,
                )
                for e in storage_events
            ]

        return session
    def save_session(self, session: Session) -> None:
        """
        Saves a session to the database.

        Args:
            session: The session to save
        """
        with self.Session() as db_session:
            storage_session = db_session.get(
                StorageSession, (session.app_name, session.user_id, session.id)
            )

            if not storage_session:
                logger.error(f"Session not found: {session.id}")
                return

            # Check states
            storage_app_state = db_session.get(StorageAppState, (session.app_name))
            storage_user_state = db_session.get(
                StorageUserState, (session.app_name, session.user_id)
            )

            # Extract state deltas
            app_state_delta = {}
            user_state_delta = {}
            session_state_delta = {}

            # Apply state deltas
            if storage_app_state and app_state_delta:
                storage_app_state.state.update(app_state_delta)

            if storage_user_state and user_state_delta:
                storage_user_state.state.update(user_state_delta)

            storage_session.state.update(session_state_delta)

            # Save new events
            for event in session.events:
                # Check if event already exists
                existing_event = (
                    (
                        db_session.query(StorageEvent)
                        .filter(StorageEvent.id == event.id)
                        .filter(StorageEvent.app_name == session.app_name)
                        .filter(StorageEvent.user_id == session.user_id)
                        .filter(StorageEvent.session_id == session.id)
                        .first()
                    )
                    if event.id
                    else None
                )

                if existing_event:
                    continue

                # Generate ID for the event if it doesn't exist
                if not event.id:
                    event.id = str(uuid.uuid4())

                # Create timestamp if it doesn't exist
                if not event.timestamp:
                    event.timestamp = datetime.now().timestamp()

                # Create StorageEvent object
                storage_event = StorageEvent(
                    id=event.id,
                    app_name=session.app_name,
                    user_id=session.user_id,
                    session_id=session.id,
                    invocation_id=event.invocation_id or str(uuid.uuid4()),
                    author=event.author,
                    branch=event.branch,
                    timestamp=datetime.fromtimestamp(event.timestamp),
                    actions=event.actions or {},
                    long_running_tool_ids=event.long_running_tool_ids or set(),
                    grounding_metadata=event.grounding_metadata,
                    partial=event.partial,
                    turn_complete=event.turn_complete,
                    error_code=event.error_code,
                    error_message=event.error_message,
                    interrupted=event.interrupted,
                )

                # Encode content, if it exists
                if event.content:
                    encoded_content = event.content.model_dump(exclude_none=True)
                    # Solution for serialization issues with multimedia content
                    for p in encoded_content.get("parts", []):
                        if "inline_data" in p:
                            p["inline_data"]["data"] = (
                                base64.b64encode(p["inline_data"]["data"]).decode(
                                    "utf-8"
                                ),
                            )
                    storage_event.content = encoded_content

                db_session.add(storage_event)

            # Commit changes
            db_session.commit()

            # Update timestamp in session
            db_session.refresh(storage_session)
            session.last_update_time = storage_session.update_time.timestamp()

        logger.info(f"Session saved: {session.id} with {len(session.events)} events")
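Continuing the usage sketch from above (hedged; the message text and IDs are assumed), appending a turn and persisting it:

# Hedged example: record a user turn on the session created earlier.
session.events.append(
    Event(author="user", content=Content(parts=[{"text": "Hello, agent!"}]))
)
service.save_session(session)  # assigns the missing event ID/timestamp and commits

restored = service.get_session("agent-123", "customer-42", session.id)
assert restored is not None and len(restored.events) == 1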
    def list_sessions(self, agent_id: str, external_id: str) -> List[Dict[str, Any]]:
        """
        Lists all sessions for an agent and user.

        Args:
            agent_id: Agent ID
            external_id: User ID

        Returns:
            List[Dict[str, Any]]: List of summarized sessions
        """
        with self.Session() as db_session:
            sessions = (
                db_session.query(StorageSession)
                .filter(StorageSession.app_name == agent_id)
                .filter(StorageSession.user_id == external_id)
                .all()
            )

            result = []
            for session in sessions:
                result.append(
                    {
                        "app_name": session.app_name,
                        "user_id": session.user_id,
                        "id": session.id,
                        "created_at": session.create_time.isoformat(),
                        "updated_at": session.update_time.isoformat(),
                    }
                )

            return result
    def delete_session(self, agent_id: str, external_id: str, session_id: str) -> bool:
        """
        Deletes a session from the database.

        Args:
            agent_id: Agent ID
            external_id: User ID
            session_id: Session ID

        Returns:
            bool: True if the session was deleted, False otherwise
        """
        from sqlalchemy import delete

        with self.Session() as db_session:
            stmt = delete(StorageSession).where(
                StorageSession.app_name == agent_id,
                StorageSession.user_id == external_id,
                StorageSession.id == session_id,
            )
            result = db_session.execute(stmt)
            db_session.commit()

            logger.info(f"Session deleted: {session_id}")
            return result.rowcount > 0

# Utility functions compatible with ADK


def _extract_state_delta(state: dict[str, Any]):
    """Splits a state dict into app, user, and session deltas by key prefix."""
    app_state_delta = {}
    user_state_delta = {}
    session_state_delta = {}

    if state:
        for key in state.keys():
            if key.startswith(State.APP_PREFIX):
                app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
            elif key.startswith(State.USER_PREFIX):
                user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
            elif not key.startswith(State.TEMP_PREFIX):
                session_state_delta[key] = state[key]

    return app_state_delta, user_state_delta, session_state_delta


def _merge_state(app_state, user_state, session_state):
    """Merges app, user, and session states into a single object."""
    merged_state = copy.deepcopy(session_state)

    for key in app_state.keys():
        merged_state[State.APP_PREFIX + key] = app_state[key]

    for key in user_state.keys():
        merged_state[State.USER_PREFIX + key] = user_state[key]

    return merged_state


def _decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]:
    """Decodes event content that may carry base64-encoded binary data."""
    if not content:
        return None

    for p in content.get("parts", []):
        if "inline_data" in p and isinstance(p["inline_data"].get("data"), tuple):
            p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])

    return Content.model_validate(content)
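For reference, the prefix convention round-trips like this (values are illustrative):

# Worked example of the prefix convention.
merged = _merge_state({"theme": "dark"}, {"lang": "en"}, {"step": 1})
# merged == {"step": 1, "app:theme": "dark", "user:lang": "en"}

deltas = _extract_state_delta(
    {"app:theme": "dark", "user:lang": "en", "step": 1, "temp:scratch": 0}
)
# deltas == ({"theme": "dark"}, {"lang": "en"}, {"step": 1})
# "temp:"-prefixed keys are dropped and never persisted.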
@ -27,12 +27,22 @@
 └──────────────────────────────────────────────────────────────────────────────┘
 """

-from src.config.settings import settings
+import os
 from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
 from google.adk.sessions import DatabaseSessionService
 from google.adk.memory import InMemoryMemoryService
+from dotenv import load_dotenv
+
+load_dotenv()
+
+from src.services.crewai.session_service import CrewSessionService
+
+if os.getenv("AI_ENGINE") == "crewai":
+    session_service = CrewSessionService(db_url=os.getenv("POSTGRES_CONNECTION_STRING"))
+else:
+    session_service = DatabaseSessionService(
+        db_url=os.getenv("POSTGRES_CONNECTION_STRING")
+    )

-# Initialize service instances
-session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
 artifacts_service = InMemoryArtifactService()
 memory_service = InMemoryMemoryService()
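With this hunk the session backend is selected at import time from the environment. A deployment opting into CrewAI would set something like the following (assumed values, not from the diff):

AI_ENGINE=crewai
POSTGRES_CONNECTION_STRING=postgresql://user:password@localhost:5432/evo_ai

Any other value of AI_ENGINE keeps the ADK DatabaseSessionService.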