feat(api): integrate new AI engines and update chat routes for dynamic agent handling

This commit is contained in:
Davidson Gomes 2025-05-19 15:22:37 -03:00
parent 9f176bf0e0
commit cf24a7ce5d
22 changed files with 2153 additions and 37 deletions

View File

@ -51,6 +51,9 @@ dependencies = [
"langgraph==0.4.1",
"opentelemetry-sdk==1.33.0",
"opentelemetry-exporter-otlp==1.33.0",
"mcp==1.9.0",
"crewai==0.120.1",
"crewai-tools==0.45.0",
]
[project.optional-dependencies]

View File

@ -39,6 +39,7 @@ from fastapi import (
Header,
)
from sqlalchemy.orm import Session
from src.config.settings import settings
from src.config.database import get_db
from src.core.jwt_middleware import (
get_jwt_token,
@ -49,7 +50,8 @@ from src.services import (
agent_service,
)
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
from src.services.agent_runner import run_agent, run_agent_stream
from src.services.adk.agent_runner import run_agent as run_agent_adk, run_agent_stream
from src.services.crewai.agent_runner import run_agent as run_agent_crewai
from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import (
session_service,
@ -262,7 +264,7 @@ async def websocket_chat(
@router.post(
"",
"/{agent_id}/{external_id}",
response_model=ChatResponse,
responses={
400: {"model": ErrorResponse},
@ -272,20 +274,32 @@ async def websocket_chat(
)
async def chat(
request: ChatRequest,
agent_id: str,
external_id: str,
_=Depends(get_agent_by_api_key),
db: Session = Depends(get_db),
):
try:
final_response = await run_agent(
request.agent_id,
request.external_id,
request.message,
session_service,
artifacts_service,
memory_service,
db,
files=request.files,
)
if settings.AI_ENGINE == "adk":
final_response = await run_agent_adk(
agent_id,
external_id,
request.message,
session_service,
artifacts_service,
memory_service,
db,
files=request.files,
)
elif settings.AI_ENGINE == "crewai":
final_response = await run_agent_crewai(
agent_id,
external_id,
request.message,
session_service,
db,
files=request.files,
)
return {
"response": final_response["final_response"],

View File

@ -0,0 +1,3 @@
from src.config.settings import settings
__all__ = ["settings"]

View File

@ -57,6 +57,9 @@ class Settings(BaseSettings):
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
)
# AI engine settings
AI_ENGINE: str = os.getenv("AI_ENGINE", "adk")
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_DIR: str = "logs"
@ -83,11 +86,11 @@ class Settings(BaseSettings):
# Email provider settings
EMAIL_PROVIDER: str = os.getenv("EMAIL_PROVIDER", "sendgrid")
# SendGrid settings
SENDGRID_API_KEY: str = os.getenv("SENDGRID_API_KEY", "")
EMAIL_FROM: str = os.getenv("EMAIL_FROM", "noreply@yourdomain.com")
# SMTP settings
SMTP_HOST: str = os.getenv("SMTP_HOST", "")
SMTP_PORT: int = int(os.getenv("SMTP_PORT", 587))
@ -96,7 +99,7 @@ class Settings(BaseSettings):
SMTP_USE_TLS: bool = os.getenv("SMTP_USE_TLS", "true").lower() == "true"
SMTP_USE_SSL: bool = os.getenv("SMTP_USE_SSL", "false").lower() == "true"
SMTP_FROM: str = os.getenv("SMTP_FROM", "")
APP_URL: str = os.getenv("APP_URL", "http://localhost:8000")
# Server settings

View File

@ -43,9 +43,11 @@ class FileData(BaseModel):
class ChatRequest(BaseModel):
"""Model to represent a chat request."""
agent_id: str = Field(..., description="Agent ID to process the message")
external_id: str = Field(..., description="External ID for user identification")
message: str = Field(..., description="User message to the agent")
agent_id: Optional[str] = Field(None, description="Agent ID to process the message")
external_id: Optional[str] = Field(
None, description="External ID for user identification"
)
files: Optional[List[FileData]] = Field(
None, description="List of files attached to the message"
)

View File

@ -1 +1 @@
from .agent_runner import run_agent
from .adk.agent_runner import run_agent

View File

@ -45,7 +45,7 @@ from src.services.agent_service import (
)
from src.services.mcp_server_service import get_mcp_server
from src.services.agent_runner import run_agent, run_agent_stream
from src.services.adk.agent_runner import run_agent, run_agent_stream
from src.services.service_providers import (
session_service,
artifacts_service,
@ -388,7 +388,6 @@ class A2ATaskManager:
self, request: SendTaskStreamingRequest, agent: Agent
) -> AsyncIterable[SendTaskStreamingResponse]:
"""Processes a task in streaming mode using the specified agent."""
# Extrair e processar arquivos da mesma forma que no método _process_task
query = self._extract_user_query(request.params)
try:
@ -448,21 +447,19 @@ class A2ATaskManager:
),
)
# Use os arquivos processados do _extract_user_query
files = getattr(self, "_last_processed_files", None)
# Log sobre os arquivos processados
if files:
logger.info(
f"Streaming: Passando {len(files)} arquivos processados para run_agent_stream"
f"Streaming: Uploading {len(files)} files to run_agent_stream"
)
for file_info in files:
logger.info(
f"Streaming: Arquivo sendo enviado: {file_info.filename} ({file_info.content_type})"
f"Streaming: File being sent: {file_info.filename} ({file_info.content_type})"
)
else:
logger.warning(
"Streaming: Nenhum arquivo processado disponível para enviar ao agente"
"Streaming: No processed files available to send to the agent"
)
async for chunk in run_agent_stream(
@ -473,7 +470,7 @@ class A2ATaskManager:
artifacts_service=artifacts_service,
memory_service=memory_service,
db=self.db,
files=files, # Passar os arquivos processados para o streaming
files=files,
):
try:
chunk_data = json.loads(chunk)

View File

@ -36,11 +36,11 @@ from src.schemas.schemas import Agent
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError
from src.services.agent_service import get_agent
from src.services.custom_tools import CustomToolBuilder
from src.services.mcp_service import MCPService
from src.services.custom_agents.a2a_agent import A2ACustomAgent
from src.services.custom_agents.workflow_agent import WorkflowAgent
from src.services.custom_agents.task_agent import TaskAgent
from src.services.adk.custom_tools import CustomToolBuilder
from src.services.adk.mcp_service import MCPService
from src.services.adk.custom_agents.a2a_agent import A2ACustomAgent
from src.services.adk.custom_agents.workflow_agent import WorkflowAgent
from src.services.adk.custom_agents.task_agent import TaskAgent
from src.services.apikey_service import get_decrypted_api_key
from sqlalchemy.orm import Session
from contextlib import AsyncExitStack

View File

@ -35,7 +35,7 @@ from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactServ
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError, InternalServerError
from src.services.agent_service import get_agent
from src.services.agent_builder import AgentBuilder
from src.services.adk.agent_builder import AgentBuilder
from sqlalchemy.orm import Session
from typing import Optional, AsyncGenerator
import asyncio

View File

@ -162,7 +162,7 @@ class TaskAgent(BaseAgent):
),
)
from src.services.agent_builder import AgentBuilder
from src.services.adk.agent_builder import AgentBuilder
print(f"Building agent in Task agent: {agent.name}")
agent_builder = AgentBuilder(self.db)

View File

@ -181,7 +181,7 @@ class WorkflowAgent(BaseAgent):
return
# Import moved to inside the function to avoid circular import
from src.services.agent_builder import AgentBuilder
from src.services.adk.agent_builder import AgentBuilder
agent_builder = AgentBuilder(self.db)
root_agent, exit_stack = await agent_builder.build_agent(agent)

View File

@ -0,0 +1,219 @@
"""
@author: Davidson Gomes
@file: agent_builder.py
Developed by: Davidson Gomes
Creation date: May 13, 2025
Contact: contato@evolution-api.com
@copyright © Evolution API 2025. All rights reserved.
Licensed under the Apache License, Version 2.0
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@important
For any future changes to the code in this file, it is recommended to
include, together with the modification, the information of the developer
who changed it and the date of modification.
"""
from typing import List, Tuple, Optional
from src.schemas.schemas import Agent
from src.schemas.agent_config import AgentTask
from src.services.crewai.custom_tool import CustomToolBuilder
from src.services.crewai.mcp_service import MCPService
from src.utils.logger import setup_logger
from src.services.apikey_service import get_decrypted_api_key
from sqlalchemy.orm import Session
from contextlib import AsyncExitStack
from crewai import LLM, Agent as LlmAgent, Crew, Task, Process
from datetime import datetime
import uuid
logger = setup_logger(__name__)
class AgentBuilder:
def __init__(self, db: Session):
self.db = db
self.custom_tool_builder = CustomToolBuilder()
self.mcp_service = MCPService()
async def _get_api_key(self, agent: Agent) -> str:
"""Get the API key for the agent."""
api_key = None
# Get API key from api_key_id
if hasattr(agent, "api_key_id") and agent.api_key_id:
if decrypted_key := get_decrypted_api_key(self.db, agent.api_key_id):
logger.info(f"Using stored API key for agent {agent.name}")
api_key = decrypted_key
else:
logger.error(f"Stored API key not found for agent {agent.name}")
raise ValueError(
f"API key with ID {agent.api_key_id} not found or inactive"
)
else:
# Check if there is an API key in the config (temporary field)
config_api_key = agent.config.get("api_key") if agent.config else None
if config_api_key:
logger.info(f"Using config API key for agent {agent.name}")
# Check if it is a UUID of a stored key
try:
key_id = uuid.UUID(config_api_key)
if decrypted_key := get_decrypted_api_key(self.db, key_id):
logger.info("Config API key is a valid reference")
api_key = decrypted_key
else:
# Use the key directly
api_key = config_api_key
except (ValueError, TypeError):
# It is not a UUID, use directly
api_key = config_api_key
else:
logger.error(f"No API key configured for agent {agent.name}")
raise ValueError(
f"Agent {agent.name} does not have a configured API key"
)
return api_key
async def _create_llm(self, agent: Agent) -> LLM:
"""Create an LLM from the agent data."""
api_key = await self._get_api_key(agent)
return LLM(model=agent.model, api_key=api_key)
async def _create_llm_agent(
self, agent: Agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Create an LLM agent from the agent data."""
# Get custom tools from the configuration
custom_tools = []
custom_tools = self.custom_tool_builder.build_tools(agent.config)
# # Get MCP tools from the configuration
mcp_tools = []
mcp_exit_stack = None
if agent.config.get("mcp_servers") or agent.config.get("custom_mcp_servers"):
try:
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
agent.config, self.db
)
except Exception as e:
logger.error(f"Error building MCP tools: {e}")
# Continue without MCP tools
mcp_tools = []
mcp_exit_stack = None
# # Get agent tools
# agent_tools = await self._agent_tools_builder(agent)
# Combine all tools
all_tools = custom_tools + mcp_tools
if enabled_tools:
all_tools = [tool for tool in all_tools if tool.name in enabled_tools]
logger.info(f"Enabled tools enabled. Total tools: {len(all_tools)}")
now = datetime.now()
current_datetime = now.strftime("%d/%m/%Y %H:%M")
current_day_of_week = now.strftime("%A")
current_date_iso = now.strftime("%Y-%m-%d")
current_time = now.strftime("%H:%M")
# Substitute variables in the prompt
formatted_prompt = agent.instruction.format(
current_datetime=current_datetime,
current_day_of_week=current_day_of_week,
current_date_iso=current_date_iso,
current_time=current_time,
)
llm_agent = LlmAgent(
role=agent.role,
goal=agent.goal,
backstory=formatted_prompt,
llm=await self._create_llm(agent),
tools=all_tools,
verbose=True,
cache=True,
# memory=True,
)
return llm_agent, mcp_exit_stack
async def _create_tasks(
self, agent: LlmAgent, tasks: List[AgentTask] = []
) -> List[Task]:
"""Create tasks from the agent data."""
tasks_list = []
if tasks:
tasks_list.extend(
Task(
name=task.name,
description=task.description,
expected_output=task.expected_output,
agent=agent,
verbose=True,
)
for task in tasks
)
return tasks_list
async def build_crew(self, agents: List[LlmAgent], tasks: List[Task] = []) -> Crew:
"""Create a crew from the agent data."""
return Crew(
agents=agents,
tasks=tasks,
process=Process.sequential,
verbose=True,
)
async def build_llm_agent(
self, root_agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Build an LLM agent with its sub-agents."""
logger.info("Creating LLM agent")
try:
result = await self._create_llm_agent(root_agent, enabled_tools)
if isinstance(result, tuple) and len(result) == 2:
return result
else:
return result, None
except Exception as e:
logger.error(f"Error in build_llm_agent: {e}")
raise
async def build_agent(
self, root_agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Build the appropriate agent based on the type of the root agent."""
if root_agent.type == "llm":
agent, exit_stack = await self.build_llm_agent(root_agent, enabled_tools)
return agent, exit_stack
elif root_agent.type == "a2a":
raise ValueError("A2A agents are not supported yet")
# return await self.build_a2a_agent(root_agent)
elif root_agent.type == "workflow":
raise ValueError("Workflow agents are not supported yet")
# return await self.build_workflow_agent(root_agent)
elif root_agent.type == "task":
raise ValueError("Task agents are not supported yet")
# return await self.build_task_agent(root_agent)
else:
raise ValueError(f"Invalid agent type: {root_agent.type}")
# return await self.build_composite_agent(root_agent)

View File

@ -0,0 +1,595 @@
"""
@author: Davidson Gomes
@file: agent_runner.py
Developed by: Davidson Gomes
Creation date: May 13, 2025
Contact: contato@evolution-api.com
@copyright © Evolution API 2025. All rights reserved.
Licensed under the Apache License, Version 2.0
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@important
For any future changes to the code in this file, it is recommended to
include, together with the modification, the information of the developer
who changed it and the date of modification.
"""
from crewai import Crew, Task, Agent as LlmAgent
from src.services.crewai.session_service import (
CrewSessionService,
Event,
Content,
Part,
Session,
)
from src.services.crewai.agent_builder import AgentBuilder
from src.utils.logger import setup_logger
from src.core.exceptions import AgentNotFoundError, InternalServerError
from src.services.agent_service import get_agent
from sqlalchemy.orm import Session
from typing import Optional, AsyncGenerator
import asyncio
import json
from datetime import datetime
from src.utils.otel import get_tracer
from opentelemetry import trace
import base64
logger = setup_logger(__name__)
def extract_text_from_output(crew_output):
"""Extract text from CrewOutput object."""
if hasattr(crew_output, "raw") and crew_output.raw:
return crew_output.raw
elif hasattr(crew_output, "__str__"):
return str(crew_output)
# Fallback if no text found
return "Unable to extract a valid response."
async def run_agent(
agent_id: str,
external_id: str,
message: str,
session_service: CrewSessionService,
db: Session,
session_id: Optional[str] = None,
timeout: float = 60.0,
files: Optional[list] = None,
):
tracer = get_tracer()
with tracer.start_as_current_span(
"run_agent",
attributes={
"agent_id": agent_id,
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
):
exit_stack = None
try:
logger.info(
f"Starting execution of agent {agent_id} for external_id {external_id}"
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
)
if get_root_agent is None:
raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
# Using the AgentBuilder to create the agent
agent_builder = AgentBuilder(db)
result = await agent_builder.build_agent(get_root_agent)
# Check how the result is structured
if isinstance(result, tuple) and len(result) == 2:
root_agent, exit_stack = result
else:
# If the result is not a tuple of 2 elements
root_agent = result
exit_stack = None
logger.warning("build_agent did not return an exit_stack")
# TODO: files should be processed here
# Fetch session information
crew_session_id = f"{external_id}_{agent_id}"
if session_id is None:
session_id = crew_session_id
logger.info(f"Searching session for external_id {external_id}")
try:
session = session_service.get_session(
agent_id=agent_id,
external_id=external_id,
session_id=crew_session_id,
)
except Exception as e:
logger.warning(f"Error getting session: {str(e)}")
session = None
if session is None:
logger.info(f"Creating new session for external_id {external_id}")
session = session_service.create_session(
agent_id=agent_id,
external_id=external_id,
session_id=crew_session_id,
)
# Add user message to session
session.events.append(
Event(
author="user",
content=Content(parts=[{"text": message}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session to database
session_service.save_session(session)
# Build message history for context
conversation_history = []
if session and session.events:
for event in session.events:
if event.author and event.content and event.content.parts:
for part in event.content.parts:
if isinstance(part, dict) and "text" in part:
role = "User" if event.author == "user" else "Assistant"
conversation_history.append(f"{role}: {part['text']}")
# Build description with history as context
task_description = (
f"Conversation history:\n" + "\n".join(conversation_history)
if conversation_history
else ""
)
task_description += f"\n\nCurrent user message: {message}"
task = Task(
name="resolve_user_request",
description=task_description,
expected_output="Response to the user request",
agent=root_agent,
verbose=True,
)
crew = await agent_builder.build_crew([root_agent], [task])
# Use normal kickoff or kickoff_async instead of kickoff_for_each
if hasattr(crew, "kickoff_async"):
crew_output = await crew.kickoff_async(inputs={"message": message})
else:
loop = asyncio.get_event_loop()
crew_output = await loop.run_in_executor(
None, lambda: crew.kickoff(inputs={"message": message})
)
# Extract response and add to session
final_text = extract_text_from_output(crew_output)
# Add agent response as event in session
session.events.append(
Event(
author=get_root_agent.name,
content=Content(parts=[{"text": final_text}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session with new event
session_service.save_session(session)
logger.info("Starting agent execution")
final_response_text = "No final response captured."
message_history = []
try:
response_queue = asyncio.Queue()
execution_completed = asyncio.Event()
async def process_events():
try:
# Log the result
logger.info(f"Crew output: {crew_output}")
# Signal that execution is complete
execution_completed.set()
# Extract text from CrewOutput object
final_text = "Unable to extract a valid response."
if hasattr(crew_output, "raw") and crew_output.raw:
final_text = crew_output.raw
elif hasattr(crew_output, "__str__"):
final_text = str(crew_output)
# If still empty or None, check crew artifacts
if not final_text or final_text.strip() == "":
# Try to get from agent messages
if hasattr(root_agent, "messages") and root_agent.messages:
# Get the last message from the agent
for msg in reversed(root_agent.messages):
if hasattr(msg, "content") and msg.content:
final_text = msg.content
break
# If still empty, use a fallback
if not final_text or final_text.strip() == "":
final_text = "The agent could not produce a valid response. Please try again with a different question."
# Put the extracted text in the queue
await response_queue.put(final_text)
except Exception as e:
logger.error(f"Error in process_events: {str(e)}")
# Provide a more helpful error response
error_response = f"An error occurred during processing: {str(e)}\n\nIf you are trying to use external tools such as Brave Search, please make sure the connection is working properly."
await response_queue.put(error_response)
execution_completed.set()
task = asyncio.create_task(process_events())
try:
wait_task = asyncio.create_task(execution_completed.wait())
done, pending = await asyncio.wait({wait_task}, timeout=timeout)
for p in pending:
p.cancel()
if not execution_completed.is_set():
logger.warning(
f"Agent execution timed out after {timeout} seconds"
)
await response_queue.put(
"The response took too long and was interrupted."
)
final_response_text = await response_queue.get()
except Exception as e:
logger.error(f"Error waiting for response: {str(e)}")
final_response_text = f"Error processing response: {str(e)}"
# Add the session to memory after completion
# completed_session = session_service.get_session(
# app_name=agent_id,
# user_id=external_id,
# session_id=crew_session_id,
# )
# memory_service.add_session_to_memory(completed_session)
# Cancel the processing task if it is still running
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
logger.info("Task cancelled successfully")
except Exception as e:
logger.error(f"Error cancelling task: {str(e)}")
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
raise InternalServerError(str(e)) from e
logger.info("Agent execution completed successfully")
return {
"final_response": final_response_text,
"message_history": message_history,
}
except AgentNotFoundError as e:
logger.error(f"Error processing request: {str(e)}")
raise e
except Exception as e:
logger.error(f"Internal error processing request: {str(e)}", exc_info=True)
raise InternalServerError(str(e))
finally:
# Clean up MCP connection - MUST be executed in the same task
if exit_stack:
logger.info("Closing MCP server connection...")
try:
if hasattr(exit_stack, "aclose"):
# If it's an AsyncExitStack
await exit_stack.aclose()
elif isinstance(exit_stack, list):
# If it's a list of adapters
for adapter in exit_stack:
if hasattr(adapter, "close"):
adapter.close()
except Exception as e:
logger.error(f"Error closing MCP connection: {e}")
# Do not raise the exception to not obscure the original error
def convert_sets(obj):
if isinstance(obj, set):
return list(obj)
elif isinstance(obj, dict):
return {k: convert_sets(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_sets(i) for i in obj]
else:
return obj
async def run_agent_stream(
agent_id: str,
external_id: str,
message: str,
db: Session,
session_id: Optional[str] = None,
files: Optional[list] = None,
) -> AsyncGenerator[str, None]:
tracer = get_tracer()
span = tracer.start_span(
"run_agent_stream",
attributes={
"agent_id": agent_id,
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
)
exit_stack = None
try:
with trace.use_span(span, end_on_exit=True):
try:
logger.info(
f"Starting streaming execution of agent {agent_id} for external_id {external_id}"
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
)
if get_root_agent is None:
raise AgentNotFoundError(f"Agent with ID {agent_id} not found")
# Using the AgentBuilder to create the agent
agent_builder = AgentBuilder(db)
result = await agent_builder.build_agent(get_root_agent)
# Check how the result is structured
if isinstance(result, tuple) and len(result) == 2:
root_agent, exit_stack = result
else:
# If the result is not a tuple of 2 elements
root_agent = result
exit_stack = None
logger.warning("build_agent did not return an exit_stack")
# TODO: files should be processed here
# Fetch session history if available
session_id = f"{external_id}_{agent_id}"
# Create an instance of the session service
try:
from src.config.settings import get_settings
settings = get_settings()
db_url = settings.DATABASE_URL
except ImportError:
# Fallback to local SQLite if cannot import settings
db_url = "sqlite:///data/crew_sessions.db"
session_service = CrewSessionService(db_url)
try:
# Try to get existing session
session = session_service.get_session(
agent_id=agent_id,
external_id=external_id,
session_id=session_id,
)
except Exception as e:
logger.warning(f"Could not load session: {e}")
session = None
# Build message history for context
conversation_history = []
if session and session.events:
for event in session.events:
if event.author and event.content and event.content.parts:
for part in event.content.parts:
if isinstance(part, dict) and "text" in part:
role = (
"User"
if event.author == "user"
else "Assistant"
)
conversation_history.append(
f"{role}: {part['text']}"
)
# Build description with history
task_description = (
f"Conversation history:\n" + "\n".join(conversation_history)
if conversation_history
else ""
)
task_description += f"\n\nCurrent user message: {message}"
task = Task(
name="resolve_user_request",
description=task_description,
expected_output="Response to the user request",
agent=root_agent,
verbose=True,
)
crew = await agent_builder.build_crew([root_agent], [task])
logger.info("Starting agent streaming execution")
try:
# Check if we can process messages with kickoff_for_each
if hasattr(crew, "kickoff_for_each"):
# Create input with current message
inputs = [{"message": message}]
logger.info(
f"Using kickoff_for_each for streaming with {len(inputs)} input(s)"
)
# Execute kickoff_for_each
results = crew.kickoff_for_each(inputs=inputs)
# Print results and save to session
for i, result in enumerate(results):
logger.info(f"Result of event {i+1}: {result}")
# If we have a session, save the response to it
if session:
# Add agent response as event
session.events.append(
Event(
author="agent",
content=Content(parts=[{"text": result}]),
timestamp=datetime.now().timestamp(),
)
)
# Save current session with new message
if session:
# Also add user message if it doesn't exist yet
if not any(
e.author == "user"
and any(
p.get("text") == message for p in e.content.parts
)
for e in session.events
if e.content and e.content.parts
):
session.events.append(
Event(
author="user",
content=Content(parts=[{"text": message}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session
try:
session_service.save_session(session)
logger.info(f"Session saved successfully: {session_id}")
except Exception as e:
logger.error(f"Error saving session: {e}")
# Use last result as final output
crew_output = results[-1] if results else None
else:
# CrewAI kickoff method is synchronous, fallback if kickoff_for_each not available
logger.info(
"kickoff_for_each not available, using standard kickoff for streaming"
)
crew_output = crew.kickoff()
logger.info(f"Crew output: {crew_output}")
# Extract the actual text content
if hasattr(crew_output, "raw") and crew_output.raw:
final_output = crew_output.raw
elif hasattr(crew_output, "__str__"):
final_output = str(crew_output)
else:
final_output = "Could not extract text from response"
# Save response to session (for fallback case of normal kickoff)
if session and not hasattr(crew, "kickoff_for_each"):
# Add agent response
session.events.append(
Event(
author="agent",
content=Content(parts=[{"text": final_output}]),
timestamp=datetime.now().timestamp(),
)
)
# Add user message if it doesn't exist yet
if not any(
e.author == "user"
and any(p.get("text") == message for p in e.content.parts)
for e in session.events
if e.content and e.content.parts
):
session.events.append(
Event(
author="user",
content=Content(parts=[{"text": message}]),
timestamp=datetime.now().timestamp(),
)
)
# Save session
try:
session_service.save_session(session)
logger.info(
f"Session saved successfully (method: kickoff): {session_id}"
)
except Exception as e:
logger.error(f"Error saving session: {e}")
yield json.dumps({"text": final_output})
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
raise InternalServerError(str(e)) from e
finally:
# Clean up MCP connection
if exit_stack:
logger.info("Closing MCP server connection...")
try:
if hasattr(exit_stack, "aclose"):
# If it's an AsyncExitStack
await exit_stack.aclose()
elif isinstance(exit_stack, list):
# If it's a list of adapters
for adapter in exit_stack:
if hasattr(adapter, "close"):
adapter.close()
except Exception as e:
logger.error(f"Error closing MCP connection: {e}")
# Do not raise the exception to not obscure the original error
logger.info("Agent streaming execution completed successfully")
except AgentNotFoundError as e:
logger.error(f"Error processing request: {str(e)}")
raise InternalServerError(str(e)) from e
except Exception as e:
logger.error(
f"Internal error processing request: {str(e)}", exc_info=True
)
raise InternalServerError(str(e))
finally:
span.end()

View File

@ -0,0 +1,369 @@
"""
@author: Davidson Gomes
@file: custom_tool.py
Developed by: Davidson Gomes
Creation date: May 13, 2025
Contact: contato@evolution-api.com
@copyright © Evolution API 2025. All rights reserved.
Licensed under the Apache License, Version 2.0
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@important
For any future changes to the code in this file, it is recommended to
include, together with the modification, the information of the developer
who changed it and the date of modification.
"""
from typing import Any, Dict, List, Type
from crewai.tools import BaseTool, tool
import requests
import json
from src.utils.logger import setup_logger
from pydantic import BaseModel, Field, create_model
logger = setup_logger(__name__)
class CustomToolBuilder:
def __init__(self):
self.tools = []
def _create_http_tool(self, tool_config: Dict[str, Any]) -> BaseTool:
"""Create an HTTP tool based on the provided configuration."""
# Extract configuration parameters
name = tool_config["name"]
description = tool_config["description"]
endpoint = tool_config["endpoint"]
method = tool_config["method"]
headers = tool_config.get("headers", {})
parameters = tool_config.get("parameters", {}) or {}
values = tool_config.get("values", {})
error_handling = tool_config.get("error_handling", {})
path_params = parameters.get("path_params") or {}
query_params = parameters.get("query_params") or {}
body_params = parameters.get("body_params") or {}
# Dynamic creation of the input schema for the tool
field_definitions = {}
# Add all parameters as fields
for param in (
list(path_params.keys())
+ list(query_params.keys())
+ list(body_params.keys())
):
# Default to string type for all parameters
field_definitions[param] = (
str,
Field(..., description=f"Parameter {param}"),
)
# If there are no parameters but default values, use those as optional fields
if not field_definitions and values:
for param, value in values.items():
param_type = type(value)
field_definitions[param] = (
param_type,
Field(default=value, description=f"Parameter {param}"),
)
# Create dynamic input schema model in line with the documentation
tool_input_model = create_model(
f"{name.replace(' ', '')}Input", **field_definitions
)
# Create the HTTP tool using crewai's BaseTool class
# Following the pattern in the documentation
def create_http_tool_class():
# Capture variables from outer scope
_name = name
_description = description
_tool_input_model = tool_input_model
class HttpTool(BaseTool):
name: str = _name
description: str = _description
args_schema: Type[BaseModel] = _tool_input_model
def _run(self, **kwargs):
"""Execute the HTTP request and return the result."""
try:
# Combines default values with provided values
all_values = {**values, **kwargs}
# Substitutes placeholders in headers
processed_headers = {
k: v.format(**all_values) if isinstance(v, str) else v
for k, v in headers.items()
}
# Processes path parameters
url = endpoint
for param, value in path_params.items():
if param in all_values:
url = url.replace(
f"{{{param}}}", str(all_values[param])
)
# Process query parameters
query_params_dict = {}
for param, value in query_params.items():
if isinstance(value, list):
# If the value is a list, join with comma
query_params_dict[param] = ",".join(value)
elif param in all_values:
# If the parameter is in the values, use the provided value
query_params_dict[param] = all_values[param]
else:
# Otherwise, use the default value from the configuration
query_params_dict[param] = value
# Adds default values to query params if they are not present
for param, value in values.items():
if (
param not in query_params_dict
and param not in path_params
):
query_params_dict[param] = value
body_data = {}
for param, param_config in body_params.items():
if param in all_values:
body_data[param] = all_values[param]
# Adds default values to body if they are not present
for param, value in values.items():
if (
param not in body_data
and param not in query_params_dict
and param not in path_params
):
body_data[param] = value
# Makes the HTTP request
response = requests.request(
method=method,
url=url,
headers=processed_headers,
params=query_params_dict,
json=body_data or None,
timeout=error_handling.get("timeout", 30),
)
if response.status_code >= 400:
raise requests.exceptions.HTTPError(
f"Error in the request: {response.status_code} - {response.text}"
)
# Always returns the response as a string
return json.dumps(response.json())
except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)
return HttpTool
# Create the tool instance
HttpToolClass = create_http_tool_class()
http_tool = HttpToolClass()
# Add cache function following the documentation
def http_cache_function(arguments: dict, result: str) -> bool:
"""Determines whether to cache the result based on arguments and result."""
# Default implementation: cache all successful results
try:
# If the result is parseable JSON and not an error, cache it
result_obj = json.loads(result)
return not (isinstance(result_obj, dict) and "error" in result_obj)
except Exception:
# If result is not valid JSON, don't cache
return False
# Assign the cache function to the tool
http_tool.cache_function = http_cache_function
return http_tool
def _create_http_tool_with_decorator(self, tool_config: Dict[str, Any]) -> Any:
"""Create an HTTP tool using the tool decorator."""
# Extract configuration parameters
name = tool_config["name"]
description = tool_config["description"]
endpoint = tool_config["endpoint"]
method = tool_config["method"]
headers = tool_config.get("headers", {})
parameters = tool_config.get("parameters", {}) or {}
values = tool_config.get("values", {})
error_handling = tool_config.get("error_handling", {})
path_params = parameters.get("path_params") or {}
query_params = parameters.get("query_params") or {}
body_params = parameters.get("body_params") or {}
# Create function docstring with parameter documentation
param_list = (
list(path_params.keys())
+ list(query_params.keys())
+ list(body_params.keys())
)
doc_params = []
for param in param_list:
doc_params.append(f" {param}: Parameter description")
docstring = (
f"{description}\n\nParameters:\n"
+ "\n".join(doc_params)
+ "\n\nReturns:\n String containing the response in JSON format"
)
# Create the tool function using the decorator pattern in the documentation
@tool(name=name)
def http_tool(**kwargs):
"""Tool function created dynamically."""
try:
# Combines default values with provided values
all_values = {**values, **kwargs}
# Substitutes placeholders in headers
processed_headers = {
k: v.format(**all_values) if isinstance(v, str) else v
for k, v in headers.items()
}
# Processes path parameters
url = endpoint
for param, value in path_params.items():
if param in all_values:
url = url.replace(f"{{{param}}}", str(all_values[param]))
# Process query parameters
query_params_dict = {}
for param, value in query_params.items():
if isinstance(value, list):
# If the value is a list, join with comma
query_params_dict[param] = ",".join(value)
elif param in all_values:
# If the parameter is in the values, use the provided value
query_params_dict[param] = all_values[param]
else:
# Otherwise, use the default value from the configuration
query_params_dict[param] = value
# Adds default values to query params if they are not present
for param, value in values.items():
if param not in query_params_dict and param not in path_params:
query_params_dict[param] = value
body_data = {}
for param, param_config in body_params.items():
if param in all_values:
body_data[param] = all_values[param]
# Adds default values to body if they are not present
for param, value in values.items():
if (
param not in body_data
and param not in query_params_dict
and param not in path_params
):
body_data[param] = value
# Makes the HTTP request
response = requests.request(
method=method,
url=url,
headers=processed_headers,
params=query_params_dict,
json=body_data or None,
timeout=error_handling.get("timeout", 30),
)
if response.status_code >= 400:
raise requests.exceptions.HTTPError(
f"Error in the request: {response.status_code} - {response.text}"
)
# Always returns the response as a string
return json.dumps(response.json())
except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)
# Replace the docstring
http_tool.__doc__ = docstring
# Add cache function following the documentation
def http_cache_function(arguments: dict, result: str) -> bool:
"""Determines whether to cache the result based on arguments and result."""
# Default implementation: cache all successful results
try:
# If the result is parseable JSON and not an error, cache it
result_obj = json.loads(result)
return not (isinstance(result_obj, dict) and "error" in result_obj)
except Exception:
# If result is not valid JSON, don't cache
return False
# Assign the cache function to the tool
http_tool.cache_function = http_cache_function
return http_tool
def build_tools(self, tools_config: Dict[str, Any]) -> List[BaseTool]:
"""Builds a list of tools based on the provided configuration. Accepts both 'tools' and 'custom_tools' (with http_tools)."""
self.tools = []
# Find HTTP tools configuration in various possible locations
http_tools = []
if tools_config.get("http_tools"):
http_tools = tools_config.get("http_tools", [])
elif tools_config.get("custom_tools") and tools_config["custom_tools"].get(
"http_tools"
):
http_tools = tools_config["custom_tools"].get("http_tools", [])
elif (
tools_config.get("tools")
and isinstance(tools_config["tools"], dict)
and tools_config["tools"].get("http_tools")
):
http_tools = tools_config["tools"].get("http_tools", [])
# Determine which implementation method to use (BaseTool or decorator)
use_decorator = tools_config.get("use_decorator", False)
# Create tools for each HTTP tool configuration
for http_tool_config in http_tools:
if use_decorator:
self.tools.append(
self._create_http_tool_with_decorator(http_tool_config)
)
else:
self.tools.append(self._create_http_tool(http_tool_config))
return self.tools

View File

@ -0,0 +1,264 @@
"""
@author: Davidson Gomes
@file: mcp_service.py
Developed by: Davidson Gomes
Creation date: May 13, 2025
Contact: contato@evolution-api.com
@copyright © Evolution API 2025. All rights reserved.
Licensed under the Apache License, Version 2.0
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@important
For any future changes to the code in this file, it is recommended to
include, together with the modification, the information of the developer
who changed it and the date of modification.
"""
from typing import Any, Dict, List, Optional, Tuple
from contextlib import ExitStack
import os
import sys
from src.utils.logger import setup_logger
from src.services.mcp_server_service import get_mcp_server
from sqlalchemy.orm import Session
try:
from crewai_tools import MCPServerAdapter
from mcp import StdioServerParameters
HAS_MCP_PACKAGES = True
except ImportError:
logger = setup_logger(__name__)
logger.error(
"MCP packages are not installed. Please install mcp and crewai-tools[mcp]"
)
HAS_MCP_PACKAGES = False
logger = setup_logger(__name__)
class MCPService:
def __init__(self):
self.tools = []
self.exit_stack = ExitStack()
def _connect_to_mcp_server(
self, server_config: Dict[str, Any]
) -> Tuple[List[Any], Optional[ExitStack]]:
"""Connect to a specific MCP server and return its tools."""
if not HAS_MCP_PACKAGES:
logger.error("Cannot connect to MCP server: MCP packages not installed")
return [], None
try:
# Determines the type of server (local or remote)
if "url" in server_config:
# Remote server (SSE) - Simplified approach using direct dictionary
sse_config = {"url": server_config["url"]}
# Add headers if provided
if "headers" in server_config and server_config["headers"]:
sse_config["headers"] = server_config["headers"]
# Create the MCPServerAdapter with the SSE configuration
mcp_adapter = MCPServerAdapter(sse_config)
else:
# Local server (Stdio)
command = server_config.get("command", "npx")
args = server_config.get("args", [])
# Adds environment variables if specified
env = server_config.get("env", {})
if env:
for key, value in env.items():
os.environ[key] = value
connection_params = StdioServerParameters(
command=command, args=args, env=env
)
# Create the MCPServerAdapter with the Stdio connection parameters
mcp_adapter = MCPServerAdapter(connection_params)
# Get tools from the adapter
tools = mcp_adapter.tools
# Return tools and the adapter (which serves as an exit stack)
return tools, mcp_adapter
except Exception as e:
logger.error(f"Error connecting to MCP server: {e}")
return [], None
def _filter_incompatible_tools(self, tools: List[Any]) -> List[Any]:
"""Filters incompatible tools with the model."""
problematic_tools = [
"create_pull_request_review", # This tool causes the 400 INVALID_ARGUMENT error
]
filtered_tools = []
removed_count = 0
for tool in tools:
if tool.name in problematic_tools:
logger.warning(f"Removing incompatible tool: {tool.name}")
removed_count += 1
else:
filtered_tools.append(tool)
if removed_count > 0:
logger.warning(f"Removed {removed_count} incompatible tools.")
return filtered_tools
def _filter_tools_by_agent(
self, tools: List[Any], agent_tools: List[str]
) -> List[Any]:
"""Filters tools compatible with the agent."""
if not agent_tools:
return tools
filtered_tools = []
for tool in tools:
logger.info(f"Tool: {tool.name}")
if tool.name in agent_tools:
filtered_tools.append(tool)
return filtered_tools
async def build_tools(
self, mcp_config: Dict[str, Any], db: Session
) -> Tuple[List[Any], Any]:
"""Builds a list of tools from multiple MCP servers."""
if not HAS_MCP_PACKAGES:
logger.error("Cannot build MCP tools: MCP packages not installed")
return [], None
self.tools = []
self.exit_stack = ExitStack()
adapter_list = []
try:
mcp_servers = mcp_config.get("mcp_servers", [])
if mcp_servers is not None:
# Process each MCP server in the configuration
for server in mcp_servers:
try:
# Search for the MCP server in the database
mcp_server = get_mcp_server(db, server["id"])
if not mcp_server:
logger.warning(f"MCP Server not found: {server['id']}")
continue
# Prepares the server configuration
server_config = mcp_server.config_json.copy()
# Replaces the environment variables in the config_json
if "env" in server_config and server_config["env"] is not None:
for key, value in server_config["env"].items():
if value and value.startswith("env@@"):
env_key = value.replace("env@@", "")
if server.get("envs") and env_key in server.get(
"envs", {}
):
server_config["env"][key] = server["envs"][
env_key
]
else:
logger.warning(
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
)
continue
logger.info(f"Connecting to MCP server: {mcp_server.name}")
tools, adapter = self._connect_to_mcp_server(server_config)
if tools and adapter:
# Filters incompatible tools
filtered_tools = self._filter_incompatible_tools(tools)
# Filters tools compatible with the agent
if agent_tools := server.get("tools", []):
filtered_tools = self._filter_tools_by_agent(
filtered_tools, agent_tools
)
self.tools.extend(filtered_tools)
# Add to the adapter list for cleanup later
adapter_list.append(adapter)
logger.info(
f"MCP Server {mcp_server.name} connected successfully. Added {len(filtered_tools)} tools."
)
else:
logger.warning(
f"Failed to connect or no tools available for {mcp_server.name}"
)
except Exception as e:
logger.error(
f"Error connecting to MCP server {server.get('id', 'unknown')}: {e}"
)
continue
custom_mcp_servers = mcp_config.get("custom_mcp_servers", [])
if custom_mcp_servers is not None:
# Process custom MCP servers
for server in custom_mcp_servers:
if not server:
logger.warning(
"Empty server configuration found in custom_mcp_servers"
)
continue
try:
logger.info(
f"Connecting to custom MCP server: {server.get('url', 'unknown')}"
)
tools, adapter = self._connect_to_mcp_server(server)
if tools:
self.tools.extend(tools)
else:
logger.warning("No tools returned from custom MCP server")
continue
if adapter:
adapter_list.append(adapter)
logger.info(
f"Custom MCP server connected successfully. Added {len(tools)} tools."
)
else:
logger.warning("No adapter returned from custom MCP server")
except Exception as e:
logger.error(
f"Error connecting to custom MCP server {server.get('url', 'unknown')}: {e}"
)
continue
logger.info(
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
)
except Exception as e:
# Ensure cleanup
for adapter in adapter_list:
if hasattr(adapter, "close"):
adapter.close()
logger.error(f"Fatal error connecting to MCP servers: {e}")
# Return empty lists in case of error
return [], None
# Return the tools and the adapter list for cleanup
return self.tools, adapter_list

View File

@ -0,0 +1,637 @@
"""
@author: Davidson Gomes
@file: session_service.py
Developed by: Davidson Gomes
Creation date: May 13, 2025
Contact: contato@evolution-api.com
@copyright © Evolution API 2025. All rights reserved.
Licensed under the Apache License, Version 2.0
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@important
For any future changes to the code in this file, it is recommended to
include, together with the modification, the information of the developer
who changed it and the date of modification.
"""
from datetime import datetime
import json
import uuid
import base64
import copy
from typing import Any, Dict, List, Optional, Union, Set
from sqlalchemy import create_engine, Boolean, Text, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import (
sessionmaker,
relationship,
DeclarativeBase,
Mapped,
mapped_column,
)
from sqlalchemy.sql import func
from sqlalchemy.types import DateTime, PickleType, String
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import TypeDecorator
from pydantic import BaseModel
from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class DynamicJSON(TypeDecorator):
"""JSON type compatible with ADK that uses JSONB in PostgreSQL and TEXT with JSON
serialization for other databases."""
impl = Text # Default implementation is TEXT
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
return dialect.type_descriptor(Text)
def process_bind_param(self, value, dialect):
if value is not None:
if dialect.name == "postgresql":
return value
else:
return json.dumps(value)
return value
def process_result_value(self, value, dialect):
if value is not None:
if dialect.name == "postgresql":
return value
else:
return json.loads(value)
return value
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageSession(Base):
"""Represents a session stored in the database, compatible with ADK."""
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
storage_events: Mapped[list["StorageEvent"]] = relationship(
"StorageEvent",
back_populates="storage_session",
)
def __repr__(self):
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
class StorageEvent(Base):
"""Represents an event stored in the database, compatible with ADK."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
back_populates="storage_events",
)
__table_args__ = (
ForeignKeyConstraint(
["app_name", "user_id", "session_id"],
["sessions.app_name", "sessions.user_id", "sessions.id"],
ondelete="CASCADE",
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
@long_running_tool_ids.setter
def long_running_tool_ids(self, value: set[str]):
if value is None:
self.long_running_tool_ids_json = None
else:
self.long_running_tool_ids_json = json.dumps(list(value))
class StorageAppState(Base):
"""Represents an application state stored in the database, compatible with ADK."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
class StorageUserState(Base):
"""Represents a user state stored in the database, compatible with ADK."""
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
# Pydantic model classes compatible with ADK
class State:
"""Utility class for states, compatible with ADK."""
APP_PREFIX = "app:"
USER_PREFIX = "user:"
TEMP_PREFIX = "temp:"
class Content(BaseModel):
"""Event content model, compatible with ADK."""
parts: List[Dict[str, Any]]
class Part(BaseModel):
"""Content part model, compatible with ADK."""
text: Optional[str] = None
class Event(BaseModel):
"""Event model, compatible with ADK."""
id: Optional[str] = None
author: str
branch: Optional[str] = None
invocation_id: Optional[str] = None
content: Optional[Content] = None
actions: Optional[Dict[str, Any]] = None
timestamp: Optional[float] = None
long_running_tool_ids: Optional[Set[str]] = None
grounding_metadata: Optional[Dict[str, Any]] = None
partial: Optional[bool] = None
turn_complete: Optional[bool] = None
error_code: Optional[str] = None
error_message: Optional[str] = None
interrupted: Optional[bool] = None
class Session(BaseModel):
"""Session model, compatible with ADK."""
app_name: str
user_id: str
id: str
state: Dict[str, Any] = {}
events: List[Event] = []
last_update_time: float
class Config:
arbitrary_types_allowed = True
class CrewSessionService:
"""Service for managing CrewAI agent sessions using ADK tables."""
def __init__(self, db_url: str):
"""
Initializes the session service.
Args:
db_url: Database connection URL.
"""
try:
self.engine = create_engine(db_url)
except Exception as e:
raise ValueError(f"Failed to create database engine: {e}")
# Create all tables
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)
logger.info(f"CrewSessionService started with database at {db_url}")
def create_session(
self, agent_id: str, external_id: str, session_id: Optional[str] = None
) -> Session:
"""
Creates a new session for an agent.
Args:
agent_id: Agent ID (used as app_name in ADK)
external_id: External ID (used as user_id in ADK)
session_id: Optional session ID
Returns:
Session: The created session
"""
session_id = session_id or str(uuid.uuid4())
with self.Session() as db_session:
# Check if app and user states already exist
storage_app_state = db_session.get(StorageAppState, (agent_id))
storage_user_state = db_session.get(
StorageUserState, (agent_id, external_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
# Create states if they don't exist
if not storage_app_state:
storage_app_state = StorageAppState(app_name=agent_id, state={})
db_session.add(storage_app_state)
if not storage_user_state:
storage_user_state = StorageUserState(
app_name=agent_id, user_id=external_id, state={}
)
db_session.add(storage_user_state)
# Create session
storage_session = StorageSession(
app_name=agent_id,
user_id=external_id,
id=session_id,
state={},
)
db_session.add(storage_session)
db_session.commit()
# Get timestamp
db_session.refresh(storage_session)
# Merge states for response
merged_state = _merge_state(app_state, user_state, {})
# Create Session object for return
session = Session(
app_name=agent_id,
user_id=external_id,
id=session_id,
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
logger.info(
f"Session created: {session_id} for agent {agent_id} and user {external_id}"
)
return session
def get_session(
self, agent_id: str, external_id: str, session_id: str
) -> Optional[Session]:
"""
Retrieves a session from the database.
Args:
agent_id: Agent ID
external_id: User ID
session_id: Session ID
Returns:
Optional[Session]: The retrieved session or None if not found
"""
with self.Session() as db_session:
storage_session = db_session.get(
StorageSession, (agent_id, external_id, session_id)
)
if storage_session is None:
return None
# Fetch session events
storage_events = (
db_session.query(StorageEvent)
.filter(StorageEvent.session_id == storage_session.id)
.filter(StorageEvent.app_name == agent_id)
.filter(StorageEvent.user_id == external_id)
.all()
)
# Fetch states
storage_app_state = db_session.get(StorageAppState, (agent_id))
storage_user_state = db_session.get(
StorageUserState, (agent_id, external_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
# Merge states
merged_state = _merge_state(app_state, user_state, session_state)
# Create session
session = Session(
app_name=agent_id,
user_id=external_id,
id=session_id,
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
# Add events
session.events = [
Event(
id=e.id,
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
error_message=e.error_message,
interrupted=e.interrupted,
)
for e in storage_events
]
return session
def save_session(self, session: Session) -> None:
"""
Saves a session to the database.
Args:
session: The session to save
"""
with self.Session() as db_session:
storage_session = db_session.get(
StorageSession, (session.app_name, session.user_id, session.id)
)
if not storage_session:
logger.error(f"Session not found: {session.id}")
return
# Check states
storage_app_state = db_session.get(StorageAppState, (session.app_name))
storage_user_state = db_session.get(
StorageUserState, (session.app_name, session.user_id)
)
# Extract state deltas
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
# Apply state deltas
if storage_app_state and app_state_delta:
storage_app_state.state.update(app_state_delta)
if storage_user_state and user_state_delta:
storage_user_state.state.update(user_state_delta)
storage_session.state.update(session_state_delta)
# Save new events
for event in session.events:
# Check if event already exists
existing_event = (
(
db_session.query(StorageEvent)
.filter(StorageEvent.id == event.id)
.filter(StorageEvent.app_name == session.app_name)
.filter(StorageEvent.user_id == session.user_id)
.filter(StorageEvent.session_id == session.id)
.first()
)
if event.id
else None
)
if existing_event:
continue
# Generate ID for the event if it doesn't exist
if not event.id:
event.id = str(uuid.uuid4())
# Create timestamp if it doesn't exist
if not event.timestamp:
event.timestamp = datetime.now().timestamp()
# Create StorageEvent object
storage_event = StorageEvent(
id=event.id,
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
invocation_id=event.invocation_id or str(uuid.uuid4()),
author=event.author,
branch=event.branch,
timestamp=datetime.fromtimestamp(event.timestamp),
actions=event.actions or {},
long_running_tool_ids=event.long_running_tool_ids or set(),
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
# Encode content, if it exists
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Solution for serialization issues with multimedia content
for p in encoded_content.get("parts", []):
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode(
"utf-8"
),
)
storage_event.content = encoded_content
db_session.add(storage_event)
# Commit changes
db_session.commit()
# Update timestamp in session
db_session.refresh(storage_session)
session.last_update_time = storage_session.update_time.timestamp()
logger.info(f"Session saved: {session.id} with {len(session.events)} events")
def list_sessions(self, agent_id: str, external_id: str) -> List[Dict[str, Any]]:
"""
Lists all sessions for an agent and user.
Args:
agent_id: Agent ID
external_id: User ID
Returns:
List[Dict[str, Any]]: List of summarized sessions
"""
with self.Session() as db_session:
sessions = (
db_session.query(StorageSession)
.filter(StorageSession.app_name == agent_id)
.filter(StorageSession.user_id == external_id)
.all()
)
result = []
for session in sessions:
result.append(
{
"app_name": session.app_name,
"user_id": session.user_id,
"id": session.id,
"created_at": session.create_time.isoformat(),
"updated_at": session.update_time.isoformat(),
}
)
return result
def delete_session(self, agent_id: str, external_id: str, session_id: str) -> bool:
"""
Deletes a session from the database.
Args:
agent_id: Agent ID
external_id: User ID
session_id: Session ID
Returns:
bool: True if the session was deleted, False otherwise
"""
from sqlalchemy import delete
with self.Session() as db_session:
stmt = delete(StorageSession).where(
StorageSession.app_name == agent_id,
StorageSession.user_id == external_id,
StorageSession.id == session_id,
)
result = db_session.execute(stmt)
db_session.commit()
logger.info(f"Session deleted: {session_id}")
return result.rowcount > 0
# Utility functions compatible with ADK
def _extract_state_delta(state: dict[str, Any]):
"""Extracts state deltas between app, user, and session."""
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if state:
for key in state.keys():
if key.startswith(State.APP_PREFIX):
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
session_state_delta[key] = state[key]
return app_state_delta, user_state_delta, session_state_delta
def _merge_state(app_state, user_state, session_state):
"""Merges app, user, and session states into a single object."""
merged_state = copy.deepcopy(session_state)
for key in app_state.keys():
merged_state[State.APP_PREFIX + key] = app_state[key]
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]:
"""Decodes event content potentially with binary data."""
if not content:
return None
for p in content.get("parts", []):
if "inline_data" in p and isinstance(p["inline_data"].get("data"), tuple):
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return Content.model_validate(content)

View File

@ -27,12 +27,22 @@
"""
from src.config.settings import settings
import os
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
from dotenv import load_dotenv
load_dotenv()
from src.services.crewai.session_service import CrewSessionService
if os.getenv("AI_ENGINE") == "crewai":
session_service = CrewSessionService(db_url=os.getenv("POSTGRES_CONNECTION_STRING"))
else:
session_service = DatabaseSessionService(
db_url=os.getenv("POSTGRES_CONNECTION_STRING")
)
# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()