feat(api): integrate new AI engines and update chat routes for dynamic agent handling

This commit is contained in:
Davidson Gomes
2025-05-19 15:22:37 -03:00
parent 9f176bf0e0
commit cf24a7ce5d
22 changed files with 2153 additions and 37 deletions

View File

@@ -51,6 +51,9 @@ dependencies = [
"langgraph==0.4.1",
"opentelemetry-sdk==1.33.0",
"opentelemetry-exporter-otlp==1.33.0",
"mcp==1.9.0",
"crewai==0.120.1",
"crewai-tools==0.45.0",
]
[project.optional-dependencies]

View File

@@ -39,6 +39,7 @@ from fastapi import (
Header,
)
from sqlalchemy.orm import Session
from src.config.settings import settings
from src.config.database import get_db
from src.core.jwt_middleware import (
get_jwt_token,
@@ -49,7 +50,8 @@ from src.services import (
agent_service,
)
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
from src.services.agent_runner import run_agent, run_agent_stream
from src.services.adk.agent_runner import run_agent as run_agent_adk, run_agent_stream
from src.services.crewai.agent_runner import run_agent as run_agent_crewai
from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import (
session_service,
@@ -262,7 +264,7 @@ async def websocket_chat(
@router.post(
"",
"/{agent_id}/{external_id}",
response_model=ChatResponse,
responses={
400: {"model": ErrorResponse},
@@ -272,20 +274,32 @@ async def websocket_chat(
)
async def chat(
request: ChatRequest,
agent_id: str,
external_id: str,
_=Depends(get_agent_by_api_key),
db: Session = Depends(get_db),
):
try:
final_response = await run_agent(
request.agent_id,
request.external_id,
request.message,
session_service,
artifacts_service,
memory_service,
db,
files=request.files,
)
if settings.AI_ENGINE == "adk":
final_response = await run_agent_adk(
agent_id,
external_id,
request.message,
session_service,
artifacts_service,
memory_service,
db,
files=request.files,
)
elif settings.AI_ENGINE == "crewai":
final_response = await run_agent_crewai(
agent_id,
external_id,
request.message,
session_service,
db,
files=request.files,
)
return {
"response": final_response["final_response"],

View File

@@ -0,0 +1,3 @@
from src.config.settings import settings
__all__ = ["settings"]

View File

@@ -57,6 +57,9 @@ class Settings(BaseSettings):
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
)
# AI engine settings
AI_ENGINE: str = os.getenv("AI_ENGINE", "adk")
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_DIR: str = "logs"

View File

@@ -43,9 +43,11 @@ class FileData(BaseModel):
class ChatRequest(BaseModel):
"""Model to represent a chat request."""
agent_id: str = Field(..., description="Agent ID to process the message")
external_id: str = Field(..., description="External ID for user identification")
message: str = Field(..., description="User message to the agent")
agent_id: Optional[str] = Field(None, description="Agent ID to process the message")
external_id: Optional[str] = Field(
None, description="External ID for user identification"
)
files: Optional[List[FileData]] = Field(
None, description="List of files attached to the message"
)

View File

@@ -1 +1 @@
from .agent_runner import run_agent
from .adk.agent_runner import run_agent

View File

@@ -45,7 +45,7 @@ from src.services.agent_service import (
)
from src.services.mcp_server_service import get_mcp_server
from src.services.agent_runner import run_agent, run_agent_stream
from src.services.adk.agent_runner import run_agent, run_agent_stream
from src.services.service_providers import (
session_service,
artifacts_service,
@@ -388,7 +388,6 @@ class A2ATaskManager:
self, request: SendTaskStreamingRequest, agent: Agent
) -> AsyncIterable[SendTaskStreamingResponse]:
"""Processes a task in streaming mode using the specified agent."""
# Extrair e processar arquivos da mesma forma que no método _process_task
query = self._extract_user_query(request.params)
try:
@@ -448,21 +447,19 @@ class A2ATaskManager:
),
)
# Use os arquivos processados do _extract_user_query
files = getattr(self, "_last_processed_files", None)
# Log sobre os arquivos processados
if files:
logger.info(
f"Streaming: Passando {len(files)} arquivos processados para run_agent_stream"
f"Streaming: Uploading {len(files)} files to run_agent_stream"
)
for file_info in files:
logger.info(
f"Streaming: Arquivo sendo enviado: {file_info.filename} ({file_info.content_type})"
f"Streaming: File being sent: {file_info.filename} ({file_info.content_type})"
)
else:
logger.warning(
"Streaming: Nenhum arquivo processado disponível para enviar ao agente"
"Streaming: No processed files available to send to the agent"
)
async for chunk in run_agent_stream(
@@ -473,7 +470,7 @@ class A2ATaskManager:
artifacts_service=artifacts_service,
memory_service=memory_service,
db=self.db,
files=files, # Passar os arquivos processados para o streaming
files=files,
):
try:
chunk_data = json.loads(chunk)

View File

@@ -36,11 +36,11 @@ from src.schemas.schemas import Agent
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError
from src.services.agent_service import get_agent
from src.services.custom_tools import CustomToolBuilder
from src.services.mcp_service import MCPService
from src.services.custom_agents.a2a_agent import A2ACustomAgent
from src.services.custom_agents.workflow_agent import WorkflowAgent
from src.services.custom_agents.task_agent import TaskAgent
from src.services.adk.custom_tools import CustomToolBuilder
from src.services.adk.mcp_service import MCPService
from src.services.adk.custom_agents.a2a_agent import A2ACustomAgent
from src.services.adk.custom_agents.workflow_agent import WorkflowAgent
from src.services.adk.custom_agents.task_agent import TaskAgent
from src.services.apikey_service import get_decrypted_api_key
from sqlalchemy.orm import Session
from contextlib import AsyncExitStack

View File

@@ -35,7 +35,7 @@ from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactServ
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError, InternalServerError
from src.services.agent_service import get_agent
from src.services.agent_builder import AgentBuilder
from src.services.adk.agent_builder import AgentBuilder
from sqlalchemy.orm import Session
from typing import Optional, AsyncGenerator
import asyncio

View File

@@ -162,7 +162,7 @@ class TaskAgent(BaseAgent):
),
)
from src.services.agent_builder import AgentBuilder
from src.services.adk.agent_builder import AgentBuilder
print(f"Building agent in Task agent: {agent.name}")
agent_builder = AgentBuilder(self.db)

View File

@@ -181,7 +181,7 @@ class WorkflowAgent(BaseAgent):
return
# Import moved to inside the function to avoid circular import
from src.services.agent_builder import AgentBuilder
from src.services.adk.agent_builder import AgentBuilder
agent_builder = AgentBuilder(self.db)
root_agent, exit_stack = await agent_builder.build_agent(agent)

View File

@@ -0,0 +1,219 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes │
│ @file: agent_builder.py │
│ Developed by: Davidson Gomes │
│ Creation date: May 13, 2025 │
│ Contact: contato@evolution-api.com │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved. │
│ Licensed under the Apache License, Version 2.0 │
│ │
│ You may not use this file except in compliance with the License. │
│ You may obtain a copy of the License at │
│ │
│ http://www.apache.org/licenses/LICENSE-2.0 │
│ │
│ Unless required by applicable law or agreed to in writing, software │
│ distributed under the License is distributed on an "AS IS" BASIS, │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
│ See the License for the specific language governing permissions and │
│ limitations under the License. │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important │
│ For any future changes to the code in this file, it is recommended to │
│ include, together with the modification, the information of the developer │
│ who changed it and the date of modification. │
└──────────────────────────────────────────────────────────────────────────────┘
"""
from typing import List, Tuple, Optional
from src.schemas.schemas import Agent
from src.schemas.agent_config import AgentTask
from src.services.crewai.custom_tool import CustomToolBuilder
from src.services.crewai.mcp_service import MCPService
from src.utils.logger import setup_logger
from src.services.apikey_service import get_decrypted_api_key
from sqlalchemy.orm import Session
from contextlib import AsyncExitStack
from crewai import LLM, Agent as LlmAgent, Crew, Task, Process
from datetime import datetime
import uuid
logger = setup_logger(__name__)
class AgentBuilder:
def __init__(self, db: Session):
self.db = db
self.custom_tool_builder = CustomToolBuilder()
self.mcp_service = MCPService()
async def _get_api_key(self, agent: Agent) -> str:
"""Get the API key for the agent."""
api_key = None
# Get API key from api_key_id
if hasattr(agent, "api_key_id") and agent.api_key_id:
if decrypted_key := get_decrypted_api_key(self.db, agent.api_key_id):
logger.info(f"Using stored API key for agent {agent.name}")
api_key = decrypted_key
else:
logger.error(f"Stored API key not found for agent {agent.name}")
raise ValueError(
f"API key with ID {agent.api_key_id} not found or inactive"
)
else:
# Check if there is an API key in the config (temporary field)
config_api_key = agent.config.get("api_key") if agent.config else None
if config_api_key:
logger.info(f"Using config API key for agent {agent.name}")
# Check if it is a UUID of a stored key
try:
key_id = uuid.UUID(config_api_key)
if decrypted_key := get_decrypted_api_key(self.db, key_id):
logger.info("Config API key is a valid reference")
api_key = decrypted_key
else:
# Use the key directly
api_key = config_api_key
except (ValueError, TypeError):
# It is not a UUID, use directly
api_key = config_api_key
else:
logger.error(f"No API key configured for agent {agent.name}")
raise ValueError(
f"Agent {agent.name} does not have a configured API key"
)
return api_key
async def _create_llm(self, agent: Agent) -> LLM:
"""Create an LLM from the agent data."""
api_key = await self._get_api_key(agent)
return LLM(model=agent.model, api_key=api_key)
async def _create_llm_agent(
self, agent: Agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Create an LLM agent from the agent data."""
# Get custom tools from the configuration
custom_tools = []
custom_tools = self.custom_tool_builder.build_tools(agent.config)
# # Get MCP tools from the configuration
mcp_tools = []
mcp_exit_stack = None
if agent.config.get("mcp_servers") or agent.config.get("custom_mcp_servers"):
try:
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
agent.config, self.db
)
except Exception as e:
logger.error(f"Error building MCP tools: {e}")
# Continue without MCP tools
mcp_tools = []
mcp_exit_stack = None
# # Get agent tools
# agent_tools = await self._agent_tools_builder(agent)
# Combine all tools
all_tools = custom_tools + mcp_tools
if enabled_tools:
all_tools = [tool for tool in all_tools if tool.name in enabled_tools]
logger.info(f"Enabled tools enabled. Total tools: {len(all_tools)}")
now = datetime.now()
current_datetime = now.strftime("%d/%m/%Y %H:%M")
current_day_of_week = now.strftime("%A")
current_date_iso = now.strftime("%Y-%m-%d")
current_time = now.strftime("%H:%M")
# Substitute variables in the prompt
formatted_prompt = agent.instruction.format(
current_datetime=current_datetime,
current_day_of_week=current_day_of_week,
current_date_iso=current_date_iso,
current_time=current_time,
)
llm_agent = LlmAgent(
role=agent.role,
goal=agent.goal,
backstory=formatted_prompt,
llm=await self._create_llm(agent),
tools=all_tools,
verbose=True,
cache=True,
# memory=True,
)
return llm_agent, mcp_exit_stack
async def _create_tasks(
self, agent: LlmAgent, tasks: List[AgentTask] = []
) -> List[Task]:
"""Create tasks from the agent data."""
tasks_list = []
if tasks:
tasks_list.extend(
Task(
name=task.name,
description=task.description,
expected_output=task.expected_output,
agent=agent,
verbose=True,
)
for task in tasks
)
return tasks_list
async def build_crew(self, agents: List[LlmAgent], tasks: List[Task] = []) -> Crew:
"""Create a crew from the agent data."""
return Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
verbose=True,
)
async def build_llm_agent(
self, root_agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Build an LLM agent with its sub-agents."""
logger.info("Creating LLM agent")
try:
result = await self._create_llm_agent(root_agent, enabled_tools)
if isinstance(result, tuple) and len(result) == 2:
return result
else:
return result, None
except Exception as e:
logger.error(f"Error in build_llm_agent: {e}")
raise
async def build_agent(
self, root_agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Build the appropriate agent based on the type of the root agent."""
if root_agent.type == "llm":
agent, exit_stack = await self.build_llm_agent(root_agent, enabled_tools)
return agent, exit_stack
elif root_agent.type == "a2a":
raise ValueError("A2A agents are not supported yet")
# return await self.build_a2a_agent(root_agent)
elif root_agent.type == "workflow":
raise ValueError("Workflow agents are not supported yet")
# return await self.build_workflow_agent(root_agent)
elif root_agent.type == "task":
raise ValueError("Task agents are not supported yet")
# return await self.build_task_agent(root_agent)
else:
raise ValueError(f"Invalid agent type: {root_agent.type}")
# return await self.build_composite_agent(root_agent)

View File

@@ -0,0 +1,595 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes │
│ @file: agent_runner.py │
│ Developed by: Davidson Gomes │
│ Creation date: May 13, 2025 │
│ Contact: contato@evolution-api.com │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved. │
│ Licensed under the Apache License, Version 2.0 │
│ │
│ You may not use this file except in compliance with the License. │
│ You may obtain a copy of the License at │
│ │
│ http://www.apache.org/licenses/LICENSE-2.0 │
│ │
│ Unless required by applicable law or agreed to in writing, software │
│ distributed under the License is distributed on an "AS IS" BASIS, │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
│ See the License for the specific language governing permissions and │
│ limitations under the License. │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important │
│ For any future changes to the code in this file, it is recommended to │
│ include, together with the modification, the information of the developer │
│ who changed it and the date of modification. │
└──────────────────────────────────────────────────────────────────────────────┘
"""
from crewai import Crew, Task, Agent as LlmAgent
from src.services.crewai.session_service import (
CrewSessionService,
Event,
Content,
Part,
Session,
)
from src.services.crewai.agent_builder import AgentBuilder
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError, InternalServerError
from src.services.agent_service import get_agent
from sqlalchemy.orm import Session
from typing import Optional, AsyncGenerator
import asyncio
import json
from datetime import datetime
from src.utils.otel import get_tracer
from opentelemetry import trace
import base64
logger = setup_logger(__name__)
def extract_text_from_output(crew_output):
"""Extract text from CrewOutput object."""
if hasattr(crew_output, "raw") and crew_output.raw:
return crew_output.raw
elif hasattr(crew_output, "__str__"):
return str(crew_output)
# Fallback if no text found
return "Unable to extract a valid response."
async def run_agent(
agent_id: str,
external_id: str,
message: str,
session_service: CrewSessionService,
db: Session,
session_id: Optional[str] = None,
timeout: float = 60.0,
files: Optional[list] = None,
):
tracer = get_tracer()
with tracer.start_as_current_span(
"run_agent",
attributes={
"agent_id": agent_id,
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
):
exit_stack = None
try:
logger.info(
f"Starting execution of agent {agent_id} for external_id {external_id}"
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
)
if get_root_agent is None:
raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
# Using the AgentBuilder to create the agent
agent_builder = AgentBuilder(db)
result = await agent_builder.build_agent(get_root_agent)
# Check how the result is structured
if isinstance(result, tuple) and len(result) == 2:
root_agent, exit_stack = result
else:
# If the result is not a tuple of 2 elements
root_agent = result
exit_stack = None
logger.warning("build_agent did not return an exit_stack")
# TODO: files should be processed here
# Fetch session information
crew_session_id = f"{external_id}_{agent_id}"
if session_id is None:
session_id = crew_session_id
logger.info(f"Searching session for external_id {external_id}")
try:
session = session_service.get_session(
agent_id=agent_id,
external_id=external_id,
session_id=crew_session_id,
)
except Exception as e:
logger.warning(f"Error getting session: {str(e)}")
session = None
if session is None:
logger.info(f"Creating new session for external_id {external_id}")
session = session_service.create_session(
agent_id=agent_id,
external_id=external_id,
session_id=crew_session_id,
)
# Add user message to session
session.events.append(
Event(
author="user",
content=Content(parts=[{"text": message}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session to database
session_service.save_session(session)
# Build message history for context
conversation_history = []
if session and session.events:
for event in session.events:
if event.author and event.content and event.content.parts:
for part in event.content.parts:
if isinstance(part, dict) and "text" in part:
role = "User" if event.author == "user" else "Assistant"
conversation_history.append(f"{role}: {part['text']}")
# Build description with history as context
task_description = (
f"Conversation history:\n" + "\n".join(conversation_history)
if conversation_history
else ""
)
task_description += f"\n\nCurrent user message: {message}"
task = Task(
name="resolve_user_request",
description=task_description,
expected_output="Response to the user request",
agent=root_agent,
verbose=True,
)
crew = await agent_builder.build_crew([root_agent], [task])
# Use normal kickoff or kickoff_async instead of kickoff_for_each
if hasattr(crew, "kickoff_async"):
crew_output = await crew.kickoff_async(inputs={"message": message})
else:
loop = asyncio.get_event_loop()
crew_output = await loop.run_in_executor(
None, lambda: crew.kickoff(inputs={"message": message})
)
# Extract response and add to session
final_text = extract_text_from_output(crew_output)
# Add agent response as event in session
session.events.append(
Event(
author=get_root_agent.name,
content=Content(parts=[{"text": final_text}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session with new event
session_service.save_session(session)
logger.info("Starting agent execution")
final_response_text = "No final response captured."
message_history = []
try:
response_queue = asyncio.Queue()
execution_completed = asyncio.Event()
async def process_events():
try:
# Log the result
logger.info(f"Crew output: {crew_output}")
# Signal that execution is complete
execution_completed.set()
# Extract text from CrewOutput object
final_text = "Unable to extract a valid response."
if hasattr(crew_output, "raw") and crew_output.raw:
final_text = crew_output.raw
elif hasattr(crew_output, "__str__"):
final_text = str(crew_output)
# If still empty or None, check crew artifacts
if not final_text or final_text.strip() == "":
# Try to get from agent messages
if hasattr(root_agent, "messages") and root_agent.messages:
# Get the last message from the agent
for msg in reversed(root_agent.messages):
if hasattr(msg, "content") and msg.content:
final_text = msg.content
break
# If still empty, use a fallback
if not final_text or final_text.strip() == "":
final_text = "The agent could not produce a valid response. Please try again with a different question."
# Put the extracted text in the queue
await response_queue.put(final_text)
except Exception as e:
logger.error(f"Error in process_events: {str(e)}")
# Provide a more helpful error response
error_response = f"An error occurred during processing: {str(e)}\n\nIf you are trying to use external tools such as Brave Search, please make sure the connection is working properly."
await response_queue.put(error_response)
execution_completed.set()
task = asyncio.create_task(process_events())
try:
wait_task = asyncio.create_task(execution_completed.wait())
done, pending = await asyncio.wait({wait_task}, timeout=timeout)
for p in pending:
p.cancel()
if not execution_completed.is_set():
logger.warning(
f"Agent execution timed out after {timeout} seconds"
)
await response_queue.put(
"The response took too long and was interrupted."
)
final_response_text = await response_queue.get()
except Exception as e:
logger.error(f"Error waiting for response: {str(e)}")
final_response_text = f"Error processing response: {str(e)}"
# Add the session to memory after completion
# completed_session = session_service.get_session(
# app_name=agent_id,
# user_id=external_id,
# session_id=crew_session_id,
# )
# memory_service.add_session_to_memory(completed_session)
# Cancel the processing task if it is still running
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
logger.info("Task cancelled successfully")
except Exception as e:
logger.error(f"Error cancelling task: {str(e)}")
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
raise InternalServerError(str(e)) from e
logger.info("Agent execution completed successfully")
return {
"final_response": final_response_text,
"message_history": message_history,
}
except AgentNotFoundError as e:
logger.error(f"Error processing request: {str(e)}")
raise e
except Exception as e:
logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
raise InternalServerError(str(e))
finally:
# Clean up MCP connection - MUST be executed in the same task
if exit_stack:
logger.info("Closing MCP server connection...")
try:
if hasattr(exit_stack, "aclose"):
# If it's an AsyncExitStack
await exit_stack.aclose()
elif isinstance(exit_stack, list):
# If it's a list of adapters
for adapter in exit_stack:
if hasattr(adapter, "close"):
adapter.close()
except Exception as e:
logger.error(f"Error closing MCP connection: {e}")
# Do not raise the exception to not obscure the original error
def convert_sets(obj):
if isinstance(obj, set):
return list(obj)
elif isinstance(obj, dict):
return {k: convert_sets(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_sets(i) for i in obj]
else:
return obj
async def run_agent_stream(
agent_id: str,
external_id: str,
message: str,
db: Session,
session_id: Optional[str] = None,
files: Optional[list] = None,
) -> AsyncGenerator[str, None]:
tracer = get_tracer()
span = tracer.start_span(
"run_agent_stream",
attributes={
"agent_id": agent_id,
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
)
exit_stack = None
try:
with trace.use_span(span, end_on_exit=True):
try:
logger.info(
f"Starting streaming execution of agent {agent_id} for external_id {external_id}"
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
)
if get_root_agent is None:
raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
# Using the AgentBuilder to create the agent
agent_builder = AgentBuilder(db)
result = await agent_builder.build_agent(get_root_agent)
# Check how the result is structured
if isinstance(result, tuple) and len(result) == 2:
root_agent, exit_stack = result
else:
# If the result is not a tuple of 2 elements
root_agent = result
exit_stack = None
logger.warning("build_agent did not return an exit_stack")
# TODO: files should be processed here
# Fetch session history if available
session_id = f"{external_id}_{agent_id}"
# Create an instance of the session service
try:
from src.config.settings import get_settings
settings = get_settings()
db_url = settings.DATABASE_URL
except ImportError:
# Fallback to local SQLite if cannot import settings
db_url = "sqlite:///data/crew_sessions.db"
session_service = CrewSessionService(db_url)
try:
# Try to get existing session
session = session_service.get_session(
agent_id=agent_id,
external_id=external_id,
session_id=session_id,
)
except Exception as e:
logger.warning(f"Could not load session: {e}")
session = None
# Build message history for context
conversation_history = []
if session and session.events:
for event in session.events:
if event.author and event.content and event.content.parts:
for part in event.content.parts:
if isinstance(part, dict) and "text" in part:
role = (
"User"
if event.author == "user"
else "Assistant"
)
conversation_history.append(
f"{role}: {part['text']}"
)
# Build description with history
task_description = (
f"Conversation history:\n" + "\n".join(conversation_history)
if conversation_history
else ""
)
task_description += f"\n\nCurrent user message: {message}"
task = Task(
name="resolve_user_request",
description=task_description,
expected_output="Response to the user request",
agent=root_agent,
verbose=True,
)
crew = await agent_builder.build_crew([root_agent], [task])
logger.info("Starting agent streaming execution")
try:
# Check if we can process messages with kickoff_for_each
if hasattr(crew, "kickoff_for_each"):
# Create input with current message
inputs = [{"message": message}]
logger.info(
f"Using kickoff_for_each for streaming with {len(inputs)} input(s)"
)
# Execute kickoff_for_each
results = crew.kickoff_for_each(inputs=inputs)
# Print results and save to session
for i, result in enumerate(results):
logger.info(f"Result of event {i+1}: {result}")
# If we have a session, save the response to it
if session:
# Add agent response as event
session.events.append(
Event(
author="agent",
content=Content(parts=[{"text": result}]),
timestamp=datetime.now().timestamp(),
)
)
# Save current session with new message
if session:
# Also add user message if it doesn't exist yet
if not any(
e.author == "user"
and any(
p.get("text") == message for p in e.content.parts
)
for e in session.events
if e.content and e.content.parts
):
session.events.append(
Event(
author="user",
content=Content(parts=[{"text": message}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session
try:
session_service.save_session(session)
logger.info(f"Session saved successfully: {session_id}")
except Exception as e:
logger.error(f"Error saving session: {e}")
# Use last result as final output
crew_output = results[-1] if results else None
else:
# CrewAI kickoff method is synchronous, fallback if kickoff_for_each not available
logger.info(
"kickoff_for_each not available, using standard kickoff for streaming"
)
crew_output = crew.kickoff()
logger.info(f"Crew output: {crew_output}")
# Extract the actual text content
if hasattr(crew_output, "raw") and crew_output.raw:
final_output = crew_output.raw
elif hasattr(crew_output, "__str__"):
final_output = str(crew_output)
else:
final_output = "Could not extract text from response"
# Save response to session (for fallback case of normal kickoff)
if session and not hasattr(crew, "kickoff_for_each"):
# Add agent response
session.events.append(
Event(
author="agent",
content=Content(parts=[{"text": final_output}]),
timestamp=datetime.now().timestamp(),
)
)
# Add user message if it doesn't exist yet
if not any(
e.author == "user"
and any(p.get("text") == message for p in e.content.parts)
for e in session.events
if e.content and e.content.parts
):
session.events.append(
Event(
author="user",
content=Content(parts=[{"text": message}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session
try:
session_service.save_session(session)
logger.info(
f"Session saved successfully (method: kickoff): {session_id}"
)
except Exception as e:
logger.error(f"Error saving session: {e}")
yield json.dumps({"text": final_output})
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
raise InternalServerError(str(e)) from e
finally:
# Clean up MCP connection
if exit_stack:
logger.info("Closing MCP server connection...")
try:
if hasattr(exit_stack, "aclose"):
# If it's an AsyncExitStack
await exit_stack.aclose()
elif isinstance(exit_stack, list):
# If it's a list of adapters
for adapter in exit_stack:
if hasattr(adapter, "close"):
adapter.close()
except Exception as e:
logger.error(f"Error closing MCP connection: {e}")
# Do not raise the exception to not obscure the original error
logger.info("Agent streaming execution completed successfully")
except AgentNotFoundError as e:
logger.error(f"Error processing request: {str(e)}")
raise InternalServerError(str(e)) from e
except Exception as e:
logger.error(
f"Internal error processing request: {str(e)}", exc_info=True
)
raise InternalServerError(str(e))
finally:
span.end()

View File

@@ -0,0 +1,369 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes │
│ @file: custom_tool.py │
│ Developed by: Davidson Gomes │
│ Creation date: May 13, 2025 │
│ Contact: contato@evolution-api.com │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved. │
│ Licensed under the Apache License, Version 2.0 │
│ │
│ You may not use this file except in compliance with the License. │
│ You may obtain a copy of the License at │
│ │
│ http://www.apache.org/licenses/LICENSE-2.0 │
│ │
│ Unless required by applicable law or agreed to in writing, software │
│ distributed under the License is distributed on an "AS IS" BASIS, │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
│ See the License for the specific language governing permissions and │
│ limitations under the License. │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important │
│ For any future changes to the code in this file, it is recommended to │
│ include, together with the modification, the information of the developer │
│ who changed it and the date of modification. │
└──────────────────────────────────────────────────────────────────────────────┘
"""
from typing import Any, Dict, List, Type
from crewai.tools import BaseTool, tool
import requests
import json
from src.utils.logger import setup_logger
from pydantic import BaseModel, Field, create_model
logger = setup_logger(__name__)
class CustomToolBuilder:
def __init__(self):
self.tools = []
def _create_http_tool(self, tool_config: Dict[str, Any]) -> BaseTool:
"""Create an HTTP tool based on the provided configuration."""
# Extract configuration parameters
name = tool_config["name"]
description = tool_config["description"]
endpoint = tool_config["endpoint"]
method = tool_config["method"]
headers = tool_config.get("headers", {})
parameters = tool_config.get("parameters", {}) or {}
values = tool_config.get("values", {})
error_handling = tool_config.get("error_handling", {})
path_params = parameters.get("path_params") or {}
query_params = parameters.get("query_params") or {}
body_params = parameters.get("body_params") or {}
# Dynamic creation of the input schema for the tool
field_definitions = {}
# Add all parameters as fields
for param in (
list(path_params.keys())
+ list(query_params.keys())
+ list(body_params.keys())
):
# Default to string type for all parameters
field_definitions[param] = (
str,
Field(..., description=f"Parameter {param}"),
)
# If there are no parameters but default values, use those as optional fields
if not field_definitions and values:
for param, value in values.items():
param_type = type(value)
field_definitions[param] = (
param_type,
Field(default=value, description=f"Parameter {param}"),
)
# Create dynamic input schema model in line with the documentation
tool_input_model = create_model(
f"{name.replace(' ', '')}Input", **field_definitions
)
# Create the HTTP tool using crewai's BaseTool class
# Following the pattern in the documentation
def create_http_tool_class():
# Capture variables from outer scope
_name = name
_description = description
_tool_input_model = tool_input_model
class HttpTool(BaseTool):
name: str = _name
description: str = _description
args_schema: Type[BaseModel] = _tool_input_model
def _run(self, **kwargs):
"""Execute the HTTP request and return the result."""
try:
# Combines default values with provided values
all_values = {**values, **kwargs}
# Substitutes placeholders in headers
processed_headers = {
k: v.format(**all_values) if isinstance(v, str) else v
for k, v in headers.items()
}
# Processes path parameters
url = endpoint
for param, value in path_params.items():
if param in all_values:
url = url.replace(
f"{{{param}}}", str(all_values[param])
)
# Process query parameters
query_params_dict = {}
for param, value in query_params.items():
if isinstance(value, list):
# If the value is a list, join with comma
query_params_dict[param] = ",".join(value)
elif param in all_values:
# If the parameter is in the values, use the provided value
query_params_dict[param] = all_values[param]
else:
# Otherwise, use the default value from the configuration
query_params_dict[param] = value
# Adds default values to query params if they are not present
for param, value in values.items():
if (
param not in query_params_dict
and param not in path_params
):
query_params_dict[param] = value
body_data = {}
for param, param_config in body_params.items():
if param in all_values:
body_data[param] = all_values[param]
# Adds default values to body if they are not present
for param, value in values.items():
if (
param not in body_data
and param not in query_params_dict
and param not in path_params
):
body_data[param] = value
# Makes the HTTP request
response = requests.request(
method=method,
url=url,
headers=processed_headers,
params=query_params_dict,
json=body_data or None,
timeout=error_handling.get("timeout", 30),
)
if response.status_code >= 400:
raise requests.exceptions.HTTPError(
f"Error in the request: {response.status_code} - {response.text}"
)
# Always returns the response as a string
return json.dumps(response.json())
except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)
return HttpTool
# Create the tool instance
HttpToolClass = create_http_tool_class()
http_tool = HttpToolClass()
# Add cache function following the documentation
def http_cache_function(arguments: dict, result: str) -> bool:
"""Determines whether to cache the result based on arguments and result."""
# Default implementation: cache all successful results
try:
# If the result is parseable JSON and not an error, cache it
result_obj = json.loads(result)
return not (isinstance(result_obj, dict) and "error" in result_obj)
except Exception:
# If result is not valid JSON, don't cache
return False
# Assign the cache function to the tool
http_tool.cache_function = http_cache_function
return http_tool
def _create_http_tool_with_decorator(self, tool_config: Dict[str, Any]) -> Any:
"""Create an HTTP tool using the tool decorator."""
# Extract configuration parameters
name = tool_config["name"]
description = tool_config["description"]
endpoint = tool_config["endpoint"]
method = tool_config["method"]
headers = tool_config.get("headers", {})
parameters = tool_config.get("parameters", {}) or {}
values = tool_config.get("values", {})
error_handling = tool_config.get("error_handling", {})
path_params = parameters.get("path_params") or {}
query_params = parameters.get("query_params") or {}
body_params = parameters.get("body_params") or {}
# Create function docstring with parameter documentation
param_list = (
list(path_params.keys())
+ list(query_params.keys())
+ list(body_params.keys())
)
doc_params = []
for param in param_list:
doc_params.append(f" {param}: Parameter description")
docstring = (
f"{description}\n\nParameters:\n"
+ "\n".join(doc_params)
+ "\n\nReturns:\n String containing the response in JSON format"
)
# Create the tool function using the decorator pattern in the documentation
@tool(name=name)
def http_tool(**kwargs):
"""Tool function created dynamically."""
try:
# Combines default values with provided values
all_values = {**values, **kwargs}
# Substitutes placeholders in headers
processed_headers = {
k: v.format(**all_values) if isinstance(v, str) else v
for k, v in headers.items()
}
# Processes path parameters
url = endpoint
for param, value in path_params.items():
if param in all_values:
url = url.replace(f"{{{param}}}", str(all_values[param]))
# Process query parameters
query_params_dict = {}
for param, value in query_params.items():
if isinstance(value, list):
# If the value is a list, join with comma
query_params_dict[param] = ",".join(value)
elif param in all_values:
# If the parameter is in the values, use the provided value
query_params_dict[param] = all_values[param]
else:
# Otherwise, use the default value from the configuration
query_params_dict[param] = value
# Adds default values to query params if they are not present
for param, value in values.items():
if param not in query_params_dict and param not in path_params:
query_params_dict[param] = value
body_data = {}
for param, param_config in body_params.items():
if param in all_values:
body_data[param] = all_values[param]
# Adds default values to body if they are not present
for param, value in values.items():
if (
param not in body_data
and param not in query_params_dict
and param not in path_params
):
body_data[param] = value
# Makes the HTTP request
response = requests.request(
method=method,
url=url,
headers=processed_headers,
params=query_params_dict,
json=body_data or None,
timeout=error_handling.get("timeout", 30),
)
if response.status_code >= 400:
raise requests.exceptions.HTTPError(
f"Error in the request: {response.status_code} - {response.text}"
)
# Always returns the response as a string
return json.dumps(response.json())
except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)
# Replace the docstring
http_tool.__doc__ = docstring
# Add cache function following the documentation
def http_cache_function(arguments: dict, result: str) -> bool:
"""Determines whether to cache the result based on arguments and result."""
# Default implementation: cache all successful results
try:
# If the result is parseable JSON and not an error, cache it
result_obj = json.loads(result)
return not (isinstance(result_obj, dict) and "error" in result_obj)
except Exception:
# If result is not valid JSON, don't cache
return False
# Assign the cache function to the tool
http_tool.cache_function = http_cache_function
return http_tool
def build_tools(self, tools_config: Dict[str, Any]) -> List[BaseTool]:
"""Builds a list of tools based on the provided configuration. Accepts both 'tools' and 'custom_tools' (with http_tools)."""
self.tools = []
# Find HTTP tools configuration in various possible locations
http_tools = []
if tools_config.get("http_tools"):
http_tools = tools_config.get("http_tools", [])
elif tools_config.get("custom_tools") and tools_config["custom_tools"].get(
"http_tools"
):
http_tools = tools_config["custom_tools"].get("http_tools", [])
elif (
tools_config.get("tools")
and isinstance(tools_config["tools"], dict)
and tools_config["tools"].get("http_tools")
):
http_tools = tools_config["tools"].get("http_tools", [])
# Determine which implementation method to use (BaseTool or decorator)
use_decorator = tools_config.get("use_decorator", False)
# Create tools for each HTTP tool configuration
for http_tool_config in http_tools:
if use_decorator:
self.tools.append(
self._create_http_tool_with_decorator(http_tool_config)
)
else:
self.tools.append(self._create_http_tool(http_tool_config))
return self.tools

View File

@@ -0,0 +1,264 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes │
│ @file: mcp_service.py │
│ Developed by: Davidson Gomes │
│ Creation date: May 13, 2025 │
│ Contact: contato@evolution-api.com │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved. │
│ Licensed under the Apache License, Version 2.0 │
│ │
│ You may not use this file except in compliance with the License. │
│ You may obtain a copy of the License at │
│ │
│ http://www.apache.org/licenses/LICENSE-2.0 │
│ │
│ Unless required by applicable law or agreed to in writing, software │
│ distributed under the License is distributed on an "AS IS" BASIS, │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
│ See the License for the specific language governing permissions and │
│ limitations under the License. │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important │
│ For any future changes to the code in this file, it is recommended to │
│ include, together with the modification, the information of the developer │
│ who changed it and the date of modification. │
└──────────────────────────────────────────────────────────────────────────────┘
"""
from typing import Any, Dict, List, Optional, Tuple
from contextlib import ExitStack
import os
import sys
from src.utils.logger import setup_logger
from src.services.mcp_server_service import get_mcp_server
from sqlalchemy.orm import Session
try:
from crewai_tools import MCPServerAdapter
from mcp import StdioServerParameters
HAS_MCP_PACKAGES = True
except ImportError:
logger = setup_logger(__name__)
logger.error(
"MCP packages are not installed. Please install mcp and crewai-tools[mcp]"
)
HAS_MCP_PACKAGES = False
logger = setup_logger(__name__)
class MCPService:
def __init__(self):
self.tools = []
self.exit_stack = ExitStack()
def _connect_to_mcp_server(
self, server_config: Dict[str, Any]
) -> Tuple[List[Any], Optional[ExitStack]]:
"""Connect to a specific MCP server and return its tools."""
if not HAS_MCP_PACKAGES:
logger.error("Cannot connect to MCP server: MCP packages not installed")
return [], None
try:
# Determines the type of server (local or remote)
if "url" in server_config:
# Remote server (SSE) - Simplified approach using direct dictionary
sse_config = {"url": server_config["url"]}
# Add headers if provided
if "headers" in server_config and server_config["headers"]:
sse_config["headers"] = server_config["headers"]
# Create the MCPServerAdapter with the SSE configuration
mcp_adapter = MCPServerAdapter(sse_config)
else:
# Local server (Stdio)
command = server_config.get("command", "npx")
args = server_config.get("args", [])
# Adds environment variables if specified
env = server_config.get("env", {})
if env:
for key, value in env.items():
os.environ[key] = value
connection_params = StdioServerParameters(
command=command, args=args, env=env
)
# Create the MCPServerAdapter with the Stdio connection parameters
mcp_adapter = MCPServerAdapter(connection_params)
# Get tools from the adapter
tools = mcp_adapter.tools
# Return tools and the adapter (which serves as an exit stack)
return tools, mcp_adapter
except Exception as e:
logger.error(f"Error connecting to MCP server: {e}")
return [], None
def _filter_incompatible_tools(self, tools: List[Any]) -> List[Any]:
"""Filters incompatible tools with the model."""
problematic_tools = [
"create_pull_request_review", # This tool causes the 400 INVALID_ARGUMENT error
]
filtered_tools = []
removed_count = 0
for tool in tools:
if tool.name in problematic_tools:
logger.warning(f"Removing incompatible tool: {tool.name}")
removed_count += 1
else:
filtered_tools.append(tool)
if removed_count > 0:
logger.warning(f"Removed {removed_count} incompatible tools.")
return filtered_tools
def _filter_tools_by_agent(
self, tools: List[Any], agent_tools: List[str]
) -> List[Any]:
"""Filters tools compatible with the agent."""
if not agent_tools:
return tools
filtered_tools = []
for tool in tools:
logger.info(f"Tool: {tool.name}")
if tool.name in agent_tools:
filtered_tools.append(tool)
return filtered_tools
async def build_tools(
self, mcp_config: Dict[str, Any], db: Session
) -> Tuple[List[Any], Any]:
"""Builds a list of tools from multiple MCP servers."""
if not HAS_MCP_PACKAGES:
logger.error("Cannot build MCP tools: MCP packages not installed")
return [], None
self.tools = []
self.exit_stack = ExitStack()
adapter_list = []
try:
mcp_servers = mcp_config.get("mcp_servers", [])
if mcp_servers is not None:
# Process each MCP server in the configuration
for server in mcp_servers:
try:
# Search for the MCP server in the database
mcp_server = get_mcp_server(db, server["id"])
if not mcp_server:
logger.warning(f"MCP Server not found: {server['id']}")
continue
# Prepares the server configuration
server_config = mcp_server.config_json.copy()
# Replaces the environment variables in the config_json
if "env" in server_config and server_config["env"] is not None:
for key, value in server_config["env"].items():
if value and value.startswith("env@@"):
env_key = value.replace("env@@", "")
if server.get("envs") and env_key in server.get(
"envs", {}
):
server_config["env"][key] = server["envs"][
env_key
]
else:
logger.warning(
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
)
continue
logger.info(f"Connecting to MCP server: {mcp_server.name}")
tools, adapter = self._connect_to_mcp_server(server_config)
if tools and adapter:
# Filters incompatible tools
filtered_tools = self._filter_incompatible_tools(tools)
# Filters tools compatible with the agent
if agent_tools := server.get("tools", []):
filtered_tools = self._filter_tools_by_agent(
filtered_tools, agent_tools
)
self.tools.extend(filtered_tools)
# Add to the adapter list for cleanup later
adapter_list.append(adapter)
logger.info(
f"MCP Server {mcp_server.name} connected successfully. Added {len(filtered_tools)} tools."
)
else:
logger.warning(
f"Failed to connect or no tools available for {mcp_server.name}"
)
except Exception as e:
logger.error(
f"Error connecting to MCP server {server.get('id', 'unknown')}: {e}"
)
continue
custom_mcp_servers = mcp_config.get("custom_mcp_servers", [])
if custom_mcp_servers is not None:
# Process custom MCP servers
for server in custom_mcp_servers:
if not server:
logger.warning(
"Empty server configuration found in custom_mcp_servers"
)
continue
try:
logger.info(
f"Connecting to custom MCP server: {server.get('url', 'unknown')}"
)
tools, adapter = self._connect_to_mcp_server(server)
if tools:
self.tools.extend(tools)
else:
logger.warning("No tools returned from custom MCP server")
continue
if adapter:
adapter_list.append(adapter)
logger.info(
f"Custom MCP server connected successfully. Added {len(tools)} tools."
)
else:
logger.warning("No adapter returned from custom MCP server")
except Exception as e:
logger.error(
f"Error connecting to custom MCP server {server.get('url', 'unknown')}: {e}"
)
continue
logger.info(
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
)
except Exception as e:
# Ensure cleanup
for adapter in adapter_list:
if hasattr(adapter, "close"):
adapter.close()
logger.error(f"Fatal error connecting to MCP servers: {e}")
# Return empty lists in case of error
return [], None
# Return the tools and the adapter list for cleanup
return self.tools, adapter_list

View File

@@ -0,0 +1,637 @@
"""
┌──────────────────────────────────────────────────────────────────────────────┐
│ @author: Davidson Gomes │
│ @file: session_service.py │
│ Developed by: Davidson Gomes │
│ Creation date: May 13, 2025 │
│ Contact: contato@evolution-api.com │
├──────────────────────────────────────────────────────────────────────────────┤
│ @copyright © Evolution API 2025. All rights reserved. │
│ Licensed under the Apache License, Version 2.0 │
│ │
│ You may not use this file except in compliance with the License. │
│ You may obtain a copy of the License at │
│ │
│ http://www.apache.org/licenses/LICENSE-2.0 │
│ │
│ Unless required by applicable law or agreed to in writing, software │
│ distributed under the License is distributed on an "AS IS" BASIS, │
│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │
│ See the License for the specific language governing permissions and │
│ limitations under the License. │
├──────────────────────────────────────────────────────────────────────────────┤
│ @important │
│ For any future changes to the code in this file, it is recommended to │
│ include, together with the modification, the information of the developer │
│ who changed it and the date of modification. │
└──────────────────────────────────────────────────────────────────────────────┘
"""
from datetime import datetime
import json
import uuid
import base64
import copy
from typing import Any, Dict, List, Optional, Union, Set
from sqlalchemy import create_engine, Boolean, Text, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import (
sessionmaker,
relationship,
DeclarativeBase,
Mapped,
mapped_column,
)
from sqlalchemy.sql import func
from sqlalchemy.types import DateTime, PickleType, String
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import TypeDecorator
from pydantic import BaseModel
from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class DynamicJSON(TypeDecorator):
"""JSON type compatible with ADK that uses JSONB in PostgreSQL and TEXT with JSON
serialization for other databases."""
impl = Text # Default implementation is TEXT
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
return dialect.type_descriptor(Text)
def process_bind_param(self, value, dialect):
if value is not None:
if dialect.name == "postgresql":
return value
else:
return json.dumps(value)
return value
def process_result_value(self, value, dialect):
if value is not None:
if dialect.name == "postgresql":
return value
else:
return json.loads(value)
return value
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageSession(Base):
"""Represents a session stored in the database, compatible with ADK."""
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
storage_events: Mapped[list["StorageEvent"]] = relationship(
"StorageEvent",
back_populates="storage_session",
)
def __repr__(self):
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
class StorageEvent(Base):
"""Represents an event stored in the database, compatible with ADK."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
back_populates="storage_events",
)
__table_args__ = (
ForeignKeyConstraint(
["app_name", "user_id", "session_id"],
["sessions.app_name", "sessions.user_id", "sessions.id"],
ondelete="CASCADE",
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
@long_running_tool_ids.setter
def long_running_tool_ids(self, value: set[str]):
if value is None:
self.long_running_tool_ids_json = None
else:
self.long_running_tool_ids_json = json.dumps(list(value))
class StorageAppState(Base):
"""Represents an application state stored in the database, compatible with ADK."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
class StorageUserState(Base):
"""Represents a user state stored in the database, compatible with ADK."""
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
# Pydantic model classes compatible with ADK
class State:
"""Utility class for states, compatible with ADK."""
APP_PREFIX = "app:"
USER_PREFIX = "user:"
TEMP_PREFIX = "temp:"
class Content(BaseModel):
"""Event content model, compatible with ADK."""
parts: List[Dict[str, Any]]
class Part(BaseModel):
"""Content part model, compatible with ADK."""
text: Optional[str] = None
class Event(BaseModel):
"""Event model, compatible with ADK."""
id: Optional[str] = None
author: str
branch: Optional[str] = None
invocation_id: Optional[str] = None
content: Optional[Content] = None
actions: Optional[Dict[str, Any]] = None
timestamp: Optional[float] = None
long_running_tool_ids: Optional[Set[str]] = None
grounding_metadata: Optional[Dict[str, Any]] = None
partial: Optional[bool] = None
turn_complete: Optional[bool] = None
error_code: Optional[str] = None
error_message: Optional[str] = None
interrupted: Optional[bool] = None
class Session(BaseModel):
"""Session model, compatible with ADK."""
app_name: str
user_id: str
id: str
state: Dict[str, Any] = {}
events: List[Event] = []
last_update_time: float
class Config:
arbitrary_types_allowed = True
class CrewSessionService:
"""Service for managing CrewAI agent sessions using ADK tables."""
def __init__(self, db_url: str):
"""
Initializes the session service.
Args:
db_url: Database connection URL.
"""
try:
self.engine = create_engine(db_url)
except Exception as e:
raise ValueError(f"Failed to create database engine: {e}")
# Create all tables
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)
logger.info(f"CrewSessionService started with database at {db_url}")
def create_session(
self, agent_id: str, external_id: str, session_id: Optional[str] = None
) -> Session:
"""
Creates a new session for an agent.
Args:
agent_id: Agent ID (used as app_name in ADK)
external_id: External ID (used as user_id in ADK)
session_id: Optional session ID
Returns:
Session: The created session
"""
session_id = session_id or str(uuid.uuid4())
with self.Session() as db_session:
# Check if app and user states already exist
storage_app_state = db_session.get(StorageAppState, (agent_id))
storage_user_state = db_session.get(
StorageUserState, (agent_id, external_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
# Create states if they don't exist
if not storage_app_state:
storage_app_state = StorageAppState(app_name=agent_id, state={})
db_session.add(storage_app_state)
if not storage_user_state:
storage_user_state = StorageUserState(
app_name=agent_id, user_id=external_id, state={}
)
db_session.add(storage_user_state)
# Create session
storage_session = StorageSession(
app_name=agent_id,
user_id=external_id,
id=session_id,
state={},
)
db_session.add(storage_session)
db_session.commit()
# Get timestamp
db_session.refresh(storage_session)
# Merge states for response
merged_state = _merge_state(app_state, user_state, {})
# Create Session object for return
session = Session(
app_name=agent_id,
user_id=external_id,
id=session_id,
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
logger.info(
f"Session created: {session_id} for agent {agent_id} and user {external_id}"
)
return session
def get_session(
self, agent_id: str, external_id: str, session_id: str
) -> Optional[Session]:
"""
Retrieves a session from the database.
Args:
agent_id: Agent ID
external_id: User ID
session_id: Session ID
Returns:
Optional[Session]: The retrieved session or None if not found
"""
with self.Session() as db_session:
storage_session = db_session.get(
StorageSession, (agent_id, external_id, session_id)
)
if storage_session is None:
return None
# Fetch session events
storage_events = (
db_session.query(StorageEvent)
.filter(StorageEvent.session_id == storage_session.id)
.filter(StorageEvent.app_name == agent_id)
.filter(StorageEvent.user_id == external_id)
.all()
)
# Fetch states
storage_app_state = db_session.get(StorageAppState, (agent_id))
storage_user_state = db_session.get(
StorageUserState, (agent_id, external_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
# Merge states
merged_state = _merge_state(app_state, user_state, session_state)
# Create session
session = Session(
app_name=agent_id,
user_id=external_id,
id=session_id,
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
# Add events
session.events = [
Event(
id=e.id,
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
error_message=e.error_message,
interrupted=e.interrupted,
)
for e in storage_events
]
return session
def save_session(self, session: Session) -> None:
"""
Saves a session to the database.
Args:
session: The session to save
"""
with self.Session() as db_session:
storage_session = db_session.get(
StorageSession, (session.app_name, session.user_id, session.id)
)
if not storage_session:
logger.error(f"Session not found: {session.id}")
return
# Check states
storage_app_state = db_session.get(StorageAppState, (session.app_name))
storage_user_state = db_session.get(
StorageUserState, (session.app_name, session.user_id)
)
# Extract state deltas
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
# Apply state deltas
if storage_app_state and app_state_delta:
storage_app_state.state.update(app_state_delta)
if storage_user_state and user_state_delta:
storage_user_state.state.update(user_state_delta)
storage_session.state.update(session_state_delta)
# Save new events
for event in session.events:
# Check if event already exists
existing_event = (
(
db_session.query(StorageEvent)
.filter(StorageEvent.id == event.id)
.filter(StorageEvent.app_name == session.app_name)
.filter(StorageEvent.user_id == session.user_id)
.filter(StorageEvent.session_id == session.id)
.first()
)
if event.id
else None
)
if existing_event:
continue
# Generate ID for the event if it doesn't exist
if not event.id:
event.id = str(uuid.uuid4())
# Create timestamp if it doesn't exist
if not event.timestamp:
event.timestamp = datetime.now().timestamp()
# Create StorageEvent object
storage_event = StorageEvent(
id=event.id,
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
invocation_id=event.invocation_id or str(uuid.uuid4()),
author=event.author,
branch=event.branch,
timestamp=datetime.fromtimestamp(event.timestamp),
actions=event.actions or {},
long_running_tool_ids=event.long_running_tool_ids or set(),
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
# Encode content, if it exists
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Solution for serialization issues with multimedia content
for p in encoded_content.get("parts", []):
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode(
"utf-8"
),
)
storage_event.content = encoded_content
db_session.add(storage_event)
# Commit changes
db_session.commit()
# Update timestamp in session
db_session.refresh(storage_session)
session.last_update_time = storage_session.update_time.timestamp()
logger.info(f"Session saved: {session.id} with {len(session.events)} events")
def list_sessions(self, agent_id: str, external_id: str) -> List[Dict[str, Any]]:
"""
Lists all sessions for an agent and user.
Args:
agent_id: Agent ID
external_id: User ID
Returns:
List[Dict[str, Any]]: List of summarized sessions
"""
with self.Session() as db_session:
sessions = (
db_session.query(StorageSession)
.filter(StorageSession.app_name == agent_id)
.filter(StorageSession.user_id == external_id)
.all()
)
result = []
for session in sessions:
result.append(
{
"app_name": session.app_name,
"user_id": session.user_id,
"id": session.id,
"created_at": session.create_time.isoformat(),
"updated_at": session.update_time.isoformat(),
}
)
return result
def delete_session(self, agent_id: str, external_id: str, session_id: str) -> bool:
"""
Deletes a session from the database.
Args:
agent_id: Agent ID
external_id: User ID
session_id: Session ID
Returns:
bool: True if the session was deleted, False otherwise
"""
from sqlalchemy import delete
with self.Session() as db_session:
stmt = delete(StorageSession).where(
StorageSession.app_name == agent_id,
StorageSession.user_id == external_id,
StorageSession.id == session_id,
)
result = db_session.execute(stmt)
db_session.commit()
logger.info(f"Session deleted: {session_id}")
return result.rowcount > 0
# Utility functions compatible with ADK
def _extract_state_delta(state: dict[str, Any]):
"""Extracts state deltas between app, user, and session."""
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if state:
for key in state.keys():
if key.startswith(State.APP_PREFIX):
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
session_state_delta[key] = state[key]
return app_state_delta, user_state_delta, session_state_delta
def _merge_state(app_state, user_state, session_state):
"""Merges app, user, and session states into a single object."""
merged_state = copy.deepcopy(session_state)
for key in app_state.keys():
merged_state[State.APP_PREFIX + key] = app_state[key]
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]:
"""Decodes event content potentially with binary data."""
if not content:
return None
for p in content.get("parts", []):
if "inline_data" in p and isinstance(p["inline_data"].get("data"), tuple):
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return Content.model_validate(content)

View File

@@ -27,12 +27,22 @@
└──────────────────────────────────────────────────────────────────────────────┘
"""
from src.config.settings import settings
import os
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
from dotenv import load_dotenv
load_dotenv()
from src.services.crewai.session_service import CrewSessionService
if os.getenv("AI_ENGINE") == "crewai":
session_service = CrewSessionService(db_url=os.getenv("POSTGRES_CONNECTION_STRING"))
else:
session_service = DatabaseSessionService(
db_url=os.getenv("POSTGRES_CONNECTION_STRING")
)
# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()