chore: update project structure and add testing framework

This commit is contained in:
Davidson Gomes
2025-04-28 20:41:10 -03:00
parent 7af234ef48
commit e7e030dfd5
49 changed files with 1261 additions and 619 deletions

View File

@@ -1 +1 @@
from .agent_runner import run_agent
from .agent_runner import run_agent

View File

@@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest
from google.adk.tools import load_memory
from typing import Optional
import logging
import os
import requests
import os
from datetime import datetime
logger = setup_logger(__name__)
@@ -83,7 +82,7 @@ def before_model_callback(
llm_request.config.system_instruction = modified_text
logger.debug(
f"📝 System instruction updated with search results and history"
"📝 System instruction updated with search results and history"
)
else:
logger.warning("⚠️ No results found in the search")
@@ -180,11 +179,13 @@ class AgentBuilder:
mcp_tools = []
mcp_exit_stack = None
if agent.config.get("mcp_servers"):
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db)
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
agent.config, self.db
)
# Combine all tools
all_tools = custom_tools + mcp_tools
now = datetime.now()
current_datetime = now.strftime("%d/%m/%Y %H:%M")
current_day_of_week = now.strftime("%A")
@@ -201,10 +202,13 @@ class AgentBuilder:
# Check if load_memory is enabled
# before_model_callback_func = None
if agent.config.get("load_memory") == True:
if agent.config.get("load_memory"):
all_tools.append(load_memory)
# before_model_callback_func = before_model_callback
formatted_prompt = formatted_prompt + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
formatted_prompt = (
formatted_prompt
+ "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
)
return (
LlmAgent(

View File

@@ -22,9 +22,7 @@ async def run_agent(
db: Session,
):
try:
logger.info(
f"Starting execution of agent {agent_id} for contact {contact_id}"
)
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
logger.info(f"Received message: {message}")
get_root_agent = get_agent(db, agent_id)
@@ -77,15 +75,15 @@ async def run_agent(
if event.is_final_response() and event.content and event.content.parts:
final_response_text = event.content.parts[0].text
logger.info(f"Final response received: {final_response_text}")
completed_session = session_service.get_session(
app_name=agent_id,
user_id=contact_id,
session_id=session_id,
)
memory_service.add_session_to_memory(completed_session)
finally:
# Ensure the exit_stack is closed correctly
if exit_stack:

View File

@@ -216,9 +216,7 @@ async def update_agent(
return agent
except Exception as e:
db.rollback()
raise HTTPException(
status_code=500, detail=f"Error updating agent: {str(e)}"
)
raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")
def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:

View File

@@ -1,15 +1,15 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from src.models.models import AuditLog, User
from src.models.models import AuditLog
from datetime import datetime
from fastapi import Request
from typing import Optional, Dict, Any, List
import uuid
import logging
import json
logger = logging.getLogger(__name__)
def create_audit_log(
db: Session,
user_id: Optional[uuid.UUID],
@@ -17,11 +17,11 @@ def create_audit_log(
resource_type: str,
resource_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
request: Optional[Request] = None
request: Optional[Request] = None,
) -> Optional[AuditLog]:
"""
Create a new audit log
Args:
db: Database session
user_id: User ID that performed the action (or None if anonymous)
@@ -30,25 +30,25 @@ def create_audit_log(
resource_id: Resource ID (optional)
details: Additional details of the action (optional)
request: FastAPI Request object (optional, to get IP and User-Agent)
Returns:
Optional[AuditLog]: Created audit log or None in case of error
"""
try:
ip_address = None
user_agent = None
if request:
ip_address = request.client.host if hasattr(request, 'client') else None
ip_address = request.client.host if hasattr(request, "client") else None
user_agent = request.headers.get("user-agent")
# Convert details to serializable format
if details:
# Convert UUIDs to strings
for key, value in details.items():
if isinstance(value, uuid.UUID):
details[key] = str(value)
audit_log = AuditLog(
user_id=user_id,
action=action,
@@ -56,20 +56,20 @@ def create_audit_log(
resource_id=str(resource_id) if resource_id else None,
details=details,
ip_address=ip_address,
user_agent=user_agent
user_agent=user_agent,
)
db.add(audit_log)
db.commit()
db.refresh(audit_log)
logger.info(
f"Audit log created: {action} in {resource_type}" +
(f" (ID: {resource_id})" if resource_id else "")
f"Audit log created: {action} in {resource_type}"
+ (f" (ID: {resource_id})" if resource_id else "")
)
return audit_log
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error creating audit log: {str(e)}")
@@ -78,6 +78,7 @@ def create_audit_log(
logger.error(f"Unexpected error creating audit log: {str(e)}")
return None
def get_audit_logs(
db: Session,
skip: int = 0,
@@ -87,11 +88,11 @@ def get_audit_logs(
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
end_date: Optional[datetime] = None,
) -> List[AuditLog]:
"""
Get audit logs with optional filters
Args:
db: Database session
skip: Number of records to skip
@@ -102,35 +103,35 @@ def get_audit_logs(
resource_id: Filter by resource ID
start_date: Start date
end_date: End date
Returns:
List[AuditLog]: List of audit logs
"""
query = db.query(AuditLog)
# Apply filters, if provided
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if resource_type:
query = query.filter(AuditLog.resource_type == resource_type)
if resource_id:
query = query.filter(AuditLog.resource_id == resource_id)
if start_date:
query = query.filter(AuditLog.created_at >= start_date)
if end_date:
query = query.filter(AuditLog.created_at <= end_date)
# Order by creation date (most recent first)
query = query.order_by(AuditLog.created_at.desc())
# Apply pagination
query = query.offset(skip).limit(limit)
return query.all()
return query.all()

View File

@@ -16,17 +16,20 @@ logger = logging.getLogger(__name__)
# Define OAuth2 authentication scheme with password flow
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
async def get_current_user(
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> User:
"""
Get the current user from the JWT token
Args:
token: JWT token
db: Database session
Returns:
User: Current user
Raises:
HTTPException: If the token is invalid or the user is not found
"""
@@ -35,103 +38,108 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
detail="Invalid credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# Decode the token
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
# Extract token data
email: str = payload.get("sub")
if email is None:
logger.warning("Token without email (sub)")
raise credentials_exception
# Check if the token has expired
exp = payload.get("exp")
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
logger.warning(f"Token expired for {email}")
raise credentials_exception
# Create TokenData object
token_data = TokenData(
sub=email,
exp=datetime.fromtimestamp(exp),
is_admin=payload.get("is_admin", False),
client_id=payload.get("client_id")
client_id=payload.get("client_id"),
)
except JWTError as e:
logger.error(f"Error decoding JWT token: {str(e)}")
raise credentials_exception
# Search for user in the database
user = get_user_by_email(db, email=token_data.sub)
if user is None:
logger.warning(f"User not found for email: {token_data.sub}")
raise credentials_exception
if not user.is_active:
logger.warning(f"Attempt to access inactive user: {user.email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Check if the current user is active
Args:
current_user: Current user
Returns:
User: Current user if active
Raises:
HTTPException: If the user is not active
"""
if not current_user.is_active:
logger.warning(f"Attempt to access inactive user: {current_user.email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return current_user
async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
async def get_current_admin_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Check if the current user is an administrator
Args:
current_user: Current user
Returns:
User: Current user if administrator
Raises:
HTTPException: If the user is not an administrator
"""
if not current_user.is_admin:
logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}")
logger.warning(
f"Attempt to access admin by non-admin user: {current_user.email}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. Restricted to administrators."
detail="Access denied. Restricted to administrators.",
)
return current_user
def create_access_token(user: User) -> str:
"""
Create a JWT access token for the user
Args:
user: User for which to create the token
Returns:
str: JWT token
"""
@@ -140,10 +148,10 @@ def create_access_token(user: User) -> str:
"sub": user.email,
"is_admin": user.is_admin,
}
# Include client_id only if not administrator and client_id is set
if not user.is_admin and user.client_id:
token_data["client_id"] = str(user.client_id)
# Create token
return create_jwt_token(token_data)
return create_jwt_token(token_data)

View File

@@ -11,6 +11,7 @@ import logging
logger = logging.getLogger(__name__)
def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
"""Search for a client by ID"""
try:
@@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
logger.error(f"Error searching for client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for client"
detail="Error searching for client",
)
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
"""Search for all clients with pagination"""
try:
@@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
logger.error(f"Error searching for clients: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for clients"
detail="Error searching for clients",
)
def create_client(db: Session, client: ClientCreate) -> Client:
"""Create a new client"""
try:
@@ -51,19 +54,22 @@ def create_client(db: Session, client: ClientCreate) -> Client:
logger.error(f"Error creating client: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating client"
detail="Error creating client",
)
def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]:
def update_client(
db: Session, client_id: uuid.UUID, client: ClientCreate
) -> Optional[Client]:
"""Updates an existing client"""
try:
db_client = get_client(db, client_id)
if not db_client:
return None
for key, value in client.model_dump().items():
setattr(db_client, key, value)
db.commit()
db.refresh(db_client)
logger.info(f"Client updated successfully: {client_id}")
@@ -73,16 +79,17 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op
logger.error(f"Error updating client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating client"
detail="Error updating client",
)
def delete_client(db: Session, client_id: uuid.UUID) -> bool:
"""Removes a client"""
try:
db_client = get_client(db, client_id)
if not db_client:
return False
db.delete(db_client)
db.commit()
logger.info(f"Client removed successfully: {client_id}")
@@ -92,18 +99,21 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:
logger.error(f"Error removing client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing client"
detail="Error removing client",
)
def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]:
def create_client_with_user(
db: Session, client_data: ClientCreate, user_data: UserCreate
) -> Tuple[Optional[Client], str]:
"""
Creates a new client with an associated user
Args:
db: Database session
client_data: Client data to be created
user_data: User data to be created
Returns:
Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message
"""
@@ -112,27 +122,27 @@ def create_client_with_user(db: Session, client_data: ClientCreate, user_data: U
client = Client(**client_data.model_dump())
db.add(client)
db.flush() # Get client ID without committing the transaction
# Use client ID to create the associated user
user, message = create_user(db, user_data, is_admin=False, client_id=client.id)
if not user:
# If there was an error creating the user, rollback
db.rollback()
logger.error(f"Error creating user for client: {message}")
return None, f"Error creating user: {message}"
# If everything went well, commit the transaction
db.commit()
logger.info(f"Client and user created successfully: {client.id}")
return client, "Client and user created successfully"
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error creating client with user: {str(e)}")
return None, f"Error creating client with user: {str(e)}"
except Exception as e:
db.rollback()
logger.error(f"Unexpected error creating client with user: {str(e)}")
return None, f"Unexpected error: {str(e)}"
return None, f"Unexpected error: {str(e)}"

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)
def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
"""Search for a contact by ID"""
try:
@@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
logger.error(f"Error searching for contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contact"
detail="Error searching for contact",
)
def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]:
def get_contacts_by_client(
db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
) -> List[Contact]:
"""Search for contacts of a client with pagination"""
try:
return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all()
return (
db.query(Contact)
.filter(Contact.client_id == client_id)
.offset(skip)
.limit(limit)
.all()
)
except SQLAlchemyError as e:
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contacts"
detail="Error searching for contacts",
)
def create_contact(db: Session, contact: ContactCreate) -> Contact:
"""Create a new contact"""
try:
@@ -49,19 +60,22 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:
logger.error(f"Error creating contact: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating contact"
detail="Error creating contact",
)
def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]:
def update_contact(
db: Session, contact_id: uuid.UUID, contact: ContactCreate
) -> Optional[Contact]:
"""Update an existing contact"""
try:
db_contact = get_contact(db, contact_id)
if not db_contact:
return None
for key, value in contact.model_dump().items():
setattr(db_contact, key, value)
db.commit()
db.refresh(db_contact)
logger.info(f"Contact updated successfully: {contact_id}")
@@ -71,16 +85,17 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -
logger.error(f"Error updating contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating contact"
detail="Error updating contact",
)
def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
"""Remove a contact"""
try:
db_contact = get_contact(db, contact_id)
if not db_contact:
return False
db.delete(db_contact)
db.commit()
logger.info(f"Contact removed successfully: {contact_id}")
@@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
logger.error(f"Error removing contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing contact"
)
detail="Error removing contact",
)

View File

@@ -6,6 +6,7 @@ from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class CustomToolBuilder:
def __init__(self):
self.tools = []
@@ -53,7 +54,9 @@ class CustomToolBuilder:
# Adds default values to query params if they are not present
for param, value in values.items():
if param not in query_params and param not in parameters.get("path_params", {}):
if param not in query_params and param not in parameters.get(
"path_params", {}
):
query_params[param] = value
# Processa body parameters
@@ -64,7 +67,11 @@ class CustomToolBuilder:
# Adds default values to body if they are not present
for param, value in values.items():
if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}):
if (
param not in body_data
and param not in query_params
and param not in parameters.get("path_params", {})
):
body_data[param] = value
# Makes the HTTP request
@@ -74,7 +81,7 @@ class CustomToolBuilder:
headers=processed_headers,
params=query_params,
json=body_data if body_data else None,
timeout=error_handling.get("timeout", 30)
timeout=error_handling.get("timeout", 30),
)
if response.status_code >= 400:
@@ -87,30 +94,34 @@ class CustomToolBuilder:
except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(error_handling.get("fallback_response", {
"error": "tool_execution_error",
"message": str(e)
}))
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)
# Adds dynamic docstring based on the configuration
param_docs = []
# Adds path parameters
for param, value in parameters.get("path_params", {}).items():
param_docs.append(f"{param}: {value}")
# Adds query parameters
for param, value in parameters.get("query_params", {}).items():
if isinstance(value, list):
param_docs.append(f"{param}: List[{', '.join(value)}]")
else:
param_docs.append(f"{param}: {value}")
# Adds body parameters
for param, param_config in parameters.get("body_params", {}).items():
required = "Required" if param_config.get("required", False) else "Optional"
param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}")
param_docs.append(
f"{param} ({param_config['type']}, {required}): {param_config['description']}"
)
# Adds default values
if values:
param_docs.append("\nDefault values:")
@@ -119,10 +130,10 @@ class CustomToolBuilder:
http_tool.__doc__ = f"""
{description}
Parameters:
{chr(10).join(param_docs)}
Returns:
String containing the response in JSON format
"""
@@ -140,4 +151,4 @@ class CustomToolBuilder:
for http_tool_config in tools_config.get("http_tools", []):
self.tools.append(self._create_http_tool(http_tool_config))
return self.tools
return self.tools

View File

@@ -16,17 +16,18 @@ os.makedirs(templates_dir, exist_ok=True)
# Configure Jinja2 with the templates directory
env = Environment(
loader=FileSystemLoader(templates_dir),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)
def _render_template(template_name: str, context: dict) -> str:
"""
Render a template with the provided data
Args:
template_name: Template file name
context: Data to render in the template
Returns:
str: Rendered HTML
"""
@@ -37,14 +38,15 @@ def _render_template(template_name: str, context: dict) -> str:
logger.error(f"Error rendering template '{template_name}': {str(e)}")
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"
def send_verification_email(email: str, token: str) -> bool:
"""
Send a verification email to the user
Args:
email: Recipient's email
token: Email verification token
Returns:
bool: True if the email was sent successfully, False otherwise
"""
@@ -53,39 +55,47 @@ def send_verification_email(email: str, token: str) -> bool:
from_email = Email(settings.EMAIL_FROM)
to_email = To(email)
subject = "Email Verification - Evo AI"
verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"
html_content = _render_template('verification_email', {
'verification_link': verification_link,
'user_name': email.split('@')[0], # Use part of the email as temporary name
'current_year': datetime.now().year
})
html_content = _render_template(
"verification_email",
{
"verification_link": verification_link,
"user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Verification email sent to {email}")
return True
else:
logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send verification email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending verification email to {email}: {str(e)}")
return False
def send_password_reset_email(email: str, token: str) -> bool:
"""
Send a password reset email to the user
Args:
email: Recipient's email
token: Password reset token
Returns:
bool: True if the email was sent successfully, False otherwise
"""
@@ -94,39 +104,47 @@ def send_password_reset_email(email: str, token: str) -> bool:
from_email = Email(settings.EMAIL_FROM)
to_email = To(email)
subject = "Password Reset - Evo AI"
reset_link = f"{settings.APP_URL}/reset-password?token={token}"
html_content = _render_template('password_reset', {
'reset_link': reset_link,
'user_name': email.split('@')[0], # Use part of the email as temporary name
'current_year': datetime.now().year
})
html_content = _render_template(
"password_reset",
{
"reset_link": reset_link,
"user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Password reset email sent to {email}")
return True
else:
logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send password reset email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending password reset email to {email}: {str(e)}")
return False
def send_welcome_email(email: str, user_name: str = None) -> bool:
"""
Send a welcome email to the user after verification
Args:
email: Recipient's email
user_name: User's name (optional)
Returns:
bool: True if the email was sent successfully, False otherwise
"""
@@ -135,41 +153,49 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
from_email = Email(settings.EMAIL_FROM)
to_email = To(email)
subject = "Welcome to Evo AI"
dashboard_link = f"{settings.APP_URL}/dashboard"
html_content = _render_template('welcome_email', {
'dashboard_link': dashboard_link,
'user_name': user_name or email.split('@')[0],
'current_year': datetime.now().year
})
html_content = _render_template(
"welcome_email",
{
"dashboard_link": dashboard_link,
"user_name": user_name or email.split("@")[0],
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Welcome email sent to {email}")
return True
else:
logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send welcome email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending welcome email to {email}: {str(e)}")
return False
def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool:
def send_account_locked_email(
email: str, reset_token: str, failed_attempts: int, time_period: str
) -> bool:
"""
Send an email informing that the account has been locked after login attempts
Args:
email: Recipient's email
reset_token: Token to reset the password
failed_attempts: Number of failed attempts
time_period: Time period of the attempts
Returns:
bool: True if the email was sent successfully, False otherwise
"""
@@ -178,29 +204,34 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
from_email = Email(settings.EMAIL_FROM)
to_email = To(email)
subject = "Security Alert - Account Locked"
reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"
html_content = _render_template('account_locked', {
'reset_link': reset_link,
'user_name': email.split('@')[0],
'failed_attempts': failed_attempts,
'time_period': time_period,
'current_year': datetime.now().year
})
html_content = _render_template(
"account_locked",
{
"reset_link": reset_link,
"user_name": email.split("@")[0],
"failed_attempts": failed_attempts,
"time_period": time_period,
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Account locked email sent to {email}")
return True
else:
logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send account locked email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending account locked email to {email}: {str(e)}")
return False
return False

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)
def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
"""Search for an MCP server by ID"""
try:
@@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP server"
detail="Error searching for MCP server",
)
def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
"""Search for all MCP servers with pagination"""
try:
@@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer
logger.error(f"Error searching for MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP servers"
detail="Error searching for MCP servers",
)
def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
"""Create a new MCP server"""
try:
@@ -49,19 +52,22 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
logger.error(f"Error creating MCP server: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating MCP server"
detail="Error creating MCP server",
)
def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]:
def update_mcp_server(
db: Session, server_id: uuid.UUID, server: MCPServerCreate
) -> Optional[MCPServer]:
"""Update an existing MCP server"""
try:
db_server = get_mcp_server(db, server_id)
if not db_server:
return None
for key, value in server.model_dump().items():
setattr(db_server, key, value)
db.commit()
db.refresh(db_server)
logger.info(f"MCP server updated successfully: {server_id}")
@@ -71,16 +77,17 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate
logger.error(f"Error updating MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating MCP server"
detail="Error updating MCP server",
)
def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
"""Remove an MCP server"""
try:
db_server = get_mcp_server(db, server_id)
if not db_server:
return False
db.delete(db_server)
db.commit()
logger.info(f"MCP server removed successfully: {server_id}")
@@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
logger.error(f"Error removing MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing MCP server"
)
detail="Error removing MCP server",
)

View File

@@ -12,26 +12,28 @@ from sqlalchemy.orm import Session
logger = setup_logger(__name__)
class MCPService:
def __init__(self):
self.tools = []
self.exit_stack = AsyncExitStack()
async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]:
async def _connect_to_mcp_server(
self, server_config: Dict[str, Any]
) -> Tuple[List[Any], Optional[AsyncExitStack]]:
"""Connect to a specific MCP server and return its tools."""
try:
# Determines the type of server (local or remote)
if "url" in server_config:
# Remote server (SSE)
connection_params = SseServerParams(
url=server_config["url"],
headers=server_config.get("headers", {})
url=server_config["url"], headers=server_config.get("headers", {})
)
else:
# Local server (Stdio)
command = server_config.get("command", "npx")
args = server_config.get("args", [])
# Adds environment variables if specified
env = server_config.get("env", {})
if env:
@@ -39,9 +41,7 @@ class MCPService:
os.environ[key] = value
connection_params = StdioServerParameters(
command=command,
args=args,
env=env
command=command, args=args, env=env
)
tools, exit_stack = await MCPToolset.from_server(
@@ -73,8 +73,10 @@ class MCPService:
logger.warning(f"Removed {removed_count} incompatible tools.")
return filtered_tools
def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]:
def _filter_tools_by_agent(
self, tools: List[Any], agent_tools: List[str]
) -> List[Any]:
"""Filters tools compatible with the agent."""
filtered_tools = []
for tool in tools:
@@ -83,7 +85,9 @@ class MCPService:
filtered_tools.append(tool)
return filtered_tools
async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]:
async def build_tools(
self, mcp_config: Dict[str, Any], db: Session
) -> Tuple[List[Any], AsyncExitStack]:
"""Builds a list of tools from multiple MCP servers."""
self.tools = []
self.exit_stack = AsyncExitStack()
@@ -92,23 +96,25 @@ class MCPService:
for server in mcp_config.get("mcp_servers", []):
try:
# Search for the MCP server in the database
mcp_server = get_mcp_server(db, server['id'])
mcp_server = get_mcp_server(db, server["id"])
if not mcp_server:
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
continue
# Prepares the server configuration
server_config = mcp_server.config_json.copy()
# Replaces the environment variables in the config_json
if 'env' in server_config:
for key, value in server_config['env'].items():
if value.startswith('env@@'):
env_key = value.replace('env@@', '')
if env_key in server.get('envs', {}):
server_config['env'][key] = server['envs'][env_key]
if "env" in server_config:
for key, value in server_config["env"].items():
if value.startswith("env@@"):
env_key = value.replace("env@@", "")
if env_key in server.get("envs", {}):
server_config["env"][key] = server["envs"][env_key]
else:
logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}")
logger.warning(
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
)
continue
logger.info(f"Connecting to MCP server: {mcp_server.name}")
@@ -117,22 +123,30 @@ class MCPService:
if tools and exit_stack:
# Filters incompatible tools
filtered_tools = self._filter_incompatible_tools(tools)
# Filters tools compatible with the agent
agent_tools = server.get('tools', [])
filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools)
agent_tools = server.get("tools", [])
filtered_tools = self._filter_tools_by_agent(
filtered_tools, agent_tools
)
self.tools.extend(filtered_tools)
# Registers the exit_stack with the AsyncExitStack
await self.exit_stack.enter_async_context(exit_stack)
logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.")
logger.info(
f"Connected successfully. Added {len(filtered_tools)} tools."
)
else:
logger.warning(f"Failed to connect or no tools available for {mcp_server.name}")
logger.warning(
f"Failed to connect or no tools available for {mcp_server.name}"
)
except Exception as e:
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
continue
logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.")
logger.info(
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
)
return self.tools, self.exit_stack
return self.tools, self.exit_stack

View File

@@ -0,0 +1,9 @@
from src.config.settings import settings
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()

View File

@@ -66,7 +66,7 @@ def get_session_by_id(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid session ID. Expected format: app_name_user_id",
)
parts = session_id.split("_", 1)
if len(parts) != 2:
logger.error(f"Invalid session ID format: {session_id}")
@@ -74,22 +74,22 @@ def get_session_by_id(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid session ID format. Expected format: app_name_user_id",
)
user_id, app_name = parts
session = session_service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session_id,
)
if session is None:
logger.error(f"Session not found: {session_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session not found: {session_id}",
)
return session
except Exception as e:
logger.error(f"Error searching for session {session_id}: {str(e)}")
@@ -106,7 +106,7 @@ def delete_session(session_service: DatabaseSessionService, session_id: str) ->
try:
session = get_session_by_id(session_service, session_id)
# If we get here, the session exists (get_session_by_id already validates)
session_service.delete_session(
app_name=session.app_name,
user_id=session.user_id,
@@ -131,10 +131,10 @@ def get_session_events(
try:
session = get_session_by_id(session_service, session_id)
# If we get here, the session exists (get_session_by_id already validates)
if not hasattr(session, 'events') or session.events is None:
if not hasattr(session, "events") or session.events is None:
return []
return session.events
except HTTPException:
# Passes HTTP exceptions from get_session_by_id

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)
def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
"""Search for a tool by ID"""
try:
@@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
logger.error(f"Error searching for tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tool"
detail="Error searching for tool",
)
def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
"""Search for all tools with pagination"""
try:
@@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
logger.error(f"Error searching for tools: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tools"
detail="Error searching for tools",
)
def create_tool(db: Session, tool: ToolCreate) -> Tool:
"""Creates a new tool"""
try:
@@ -49,19 +52,20 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:
logger.error(f"Error creating tool: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating tool"
detail="Error creating tool",
)
def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
"""Updates an existing tool"""
try:
db_tool = get_tool(db, tool_id)
if not db_tool:
return None
for key, value in tool.model_dump().items():
setattr(db_tool, key, value)
db.commit()
db.refresh(db_tool)
logger.info(f"Tool updated successfully: {tool_id}")
@@ -71,16 +75,17 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T
logger.error(f"Error updating tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating tool"
detail="Error updating tool",
)
def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
"""Remove a tool"""
try:
db_tool = get_tool(db, tool_id)
if not db_tool:
return False
db.delete(db_tool)
db.commit()
logger.info(f"Tool removed successfully: {tool_id}")
@@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
logger.error(f"Error removing tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing tool"
)
detail="Error removing tool",
)

View File

@@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError
from src.models.models import User, Client
from src.schemas.user import UserCreate
from src.utils.security import get_password_hash, verify_password, generate_token
from src.services.email_service import send_verification_email, send_password_reset_email
from src.services.email_service import (
send_verification_email,
send_password_reset_email,
)
from datetime import datetime, timedelta
import uuid
import logging
@@ -11,16 +14,22 @@ from typing import Optional, Tuple
logger = logging.getLogger(__name__)
def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]:
def create_user(
db: Session,
user_data: UserCreate,
is_admin: bool = False,
client_id: Optional[uuid.UUID] = None,
) -> Tuple[Optional[User], str]:
"""
Creates a new user in the system
Creates a new user in the system
Args:
db: Database session
user_data: User data to be created
is_admin: If the user is an administrator
client_id: Associated client ID (optional, a new one will be created if not provided)
Returns:
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
"""
@@ -28,17 +37,19 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
# Check if email already exists
db_user = db.query(User).filter(User.email == user_data.email).first()
if db_user:
logger.warning(f"Attempt to register with existing email: {user_data.email}")
logger.warning(
f"Attempt to register with existing email: {user_data.email}"
)
return None, "Email already registered"
# Create verification token
verification_token = generate_token()
token_expiry = datetime.utcnow() + timedelta(hours=24)
# Start transaction
user = None
local_client_id = client_id
try:
# If not admin and no client_id, create an associated client
if not is_admin and local_client_id is None:
@@ -46,7 +57,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
db.add(client)
db.flush() # Get the client ID
local_client_id = client.id
# Create user
user = User(
email=user_data.email,
@@ -56,52 +67,56 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
is_active=False, # Inactive until email is verified
email_verified=False,
verification_token=verification_token,
verification_token_expiry=token_expiry
verification_token_expiry=token_expiry,
)
db.add(user)
db.commit()
# Send verification email
email_sent = send_verification_email(user.email, verification_token)
if not email_sent:
logger.error(f"Failed to send verification email to {user.email}")
# We don't do rollback here, we just log the error
logger.info(f"User created successfully: {user.email}")
return user, "User created successfully. Check your email to activate your account."
return (
user,
"User created successfully. Check your email to activate your account.",
)
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error creating user: {str(e)}")
return None, f"Error creating user: {str(e)}"
except Exception as e:
logger.error(f"Unexpected error creating user: {str(e)}")
return None, f"Unexpected error: {str(e)}"
def verify_email(db: Session, token: str) -> Tuple[bool, str]:
"""
Verify the user's email using the provided token
Args:
db: Database session
token: Verification token
Returns:
Tuple[bool, str]: Tuple with verification status and message
"""
try:
# Search for user by token
user = db.query(User).filter(User.verification_token == token).first()
if not user:
logger.warning(f"Attempt to verify with invalid token: {token}")
return False, "Invalid verification token"
# Check if the token has expired
now = datetime.utcnow()
expiry = user.verification_token_expiry
# Ensure both dates are of the same type (aware or naive)
if expiry.tzinfo is not None and now.tzinfo is None:
# If expiry has timezone and now doesn't, convert now to have timezone
@@ -109,180 +124,201 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
elif now.tzinfo is not None and expiry.tzinfo is None:
# If now has timezone and expiry doesn't, convert expiry to have timezone
expiry = expiry.replace(tzinfo=now.tzinfo)
if expiry < now:
logger.warning(f"Attempt to verify with expired token for user: {user.email}")
logger.warning(
f"Attempt to verify with expired token for user: {user.email}"
)
return False, "Verification token expired"
# Update user
user.email_verified = True
user.is_active = True
user.verification_token = None
user.verification_token_expiry = None
db.commit()
logger.info(f"Email verified successfully for user: {user.email}")
return True, "Email verified successfully. Your account is active."
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error verifying email: {str(e)}")
return False, f"Error verifying email: {str(e)}"
except Exception as e:
logger.error(f"Unexpected error verifying email: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
"""
Resend the verification email
Args:
db: Database session
email: User email
Returns:
Tuple[bool, str]: Tuple with operation status and message
"""
try:
# Search for user by email
user = db.query(User).filter(User.email == email).first()
if not user:
logger.warning(f"Attempt to resend verification email for non-existent email: {email}")
logger.warning(
f"Attempt to resend verification email for non-existent email: {email}"
)
return False, "Email not found"
if user.email_verified:
logger.info(f"Attempt to resend verification email for already verified email: {email}")
logger.info(
f"Attempt to resend verification email for already verified email: {email}"
)
return False, "Email already verified"
# Generate new token
verification_token = generate_token()
token_expiry = datetime.utcnow() + timedelta(hours=24)
# Update user
user.verification_token = verification_token
user.verification_token_expiry = token_expiry
db.commit()
# Send email
email_sent = send_verification_email(user.email, verification_token)
if not email_sent:
logger.error(f"Failed to resend verification email to {user.email}")
return False, "Failed to send verification email"
logger.info(f"Verification email resent successfully to: {user.email}")
return True, "Verification email resent. Check your inbox."
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error resending verification: {str(e)}")
return False, f"Error resending verification: {str(e)}"
except Exception as e:
logger.error(f"Unexpected error resending verification: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
"""
Initiates the password recovery process
Args:
db: Database session
email: User email
Returns:
Tuple[bool, str]: Tuple with operation status and message
"""
try:
# Search for user by email
user = db.query(User).filter(User.email == email).first()
if not user:
# For security, we don't inform if the email exists or not
logger.info(f"Attempt to recover password for non-existent email: {email}")
return True, "If the email is registered, you will receive instructions to reset your password."
return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)
# Generate reset token
reset_token = generate_token()
token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour
# Update user
user.password_reset_token = reset_token
user.password_reset_expiry = token_expiry
db.commit()
# Send email
email_sent = send_password_reset_email(user.email, reset_token)
if not email_sent:
logger.error(f"Failed to send password reset email to {user.email}")
return False, "Failed to send password reset email"
logger.info(f"Password reset email sent successfully to: {user.email}")
return True, "If the email is registered, you will receive instructions to reset your password."
return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error processing password recovery: {str(e)}")
return False, f"Error processing password recovery: {str(e)}"
except Exception as e:
logger.error(f"Unexpected error processing password recovery: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
"""
Resets the user's password using the provided token
Args:
db: Database session
token: Password reset token
new_password: New password
Returns:
Tuple[bool, str]: Tuple with operation status and message
"""
try:
# Search for user by token
user = db.query(User).filter(User.password_reset_token == token).first()
if not user:
logger.warning(f"Attempt to reset password with invalid token: {token}")
return False, "Invalid password reset token"
# Check if the token has expired
if user.password_reset_expiry < datetime.utcnow():
logger.warning(f"Attempt to reset password with expired token for user: {user.email}")
logger.warning(
f"Attempt to reset password with expired token for user: {user.email}"
)
return False, "Password reset token expired"
# Update password
user.password_hash = get_password_hash(new_password)
user.password_reset_token = None
user.password_reset_expiry = None
db.commit()
logger.info(f"Password reset successfully for user: {user.email}")
return True, "Password reset successfully. You can now login with your new password."
return (
True,
"Password reset successfully. You can now login with your new password.",
)
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error resetting password: {str(e)}")
return False, f"Error resetting password: {str(e)}"
except Exception as e:
logger.error(f"Unexpected error resetting password: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""
Searches for a user by email
Args:
db: Database session
email: User email
Returns:
Optional[User]: User found or None
"""
@@ -292,15 +328,16 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:
logger.error(f"Error searching for user by email: {str(e)}")
return None
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticates a user with email and password
Args:
db: Database session
email: User email
password: User password
Returns:
Optional[User]: Authenticated user or None
"""
@@ -313,75 +350,78 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
return None
return user
def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
"""
Lists the admin users
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
List[User]: List of admin users
"""
try:
users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all()
users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
logger.info(f"List of admins: {len(users)} found")
return users
except SQLAlchemyError as e:
logger.error(f"Error listing admins: {str(e)}")
return []
except Exception as e:
logger.error(f"Unexpected error listing admins: {str(e)}")
return []
def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
"""
Creates a new admin user
Args:
db: Database session
user_data: User data to be created
Returns:
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
"""
return create_user(db, user_data, is_admin=True)
def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
"""
Deactivates a user (does not delete, only marks as inactive)
Args:
db: Database session
user_id: ID of the user to be deactivated
Returns:
Tuple[bool, str]: Tuple with operation status and message
"""
try:
# Search for user by ID
user = db.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"Attempt to deactivate non-existent user: {user_id}")
return False, "User not found"
# Deactivate user
user.is_active = False
db.commit()
logger.info(f"User deactivated successfully: {user.email}")
return True, "User deactivated successfully"
except SQLAlchemyError as e:
db.rollback()
logger.error(f"Error deactivating user: {str(e)}")
return False, f"Error deactivating user: {str(e)}"
except Exception as e:
logger.error(f"Unexpected error deactivating user: {str(e)}")
return False, f"Unexpected error: {str(e)}"
return False, f"Unexpected error: {str(e)}"