chore: update project structure and add testing framework

This commit is contained in:
Davidson Gomes 2025-04-28 20:41:10 -03:00
parent 7af234ef48
commit e7e030dfd5
49 changed files with 1261 additions and 619 deletions

View File

@ -17,9 +17,16 @@
```
src/
├── api/
│ ├── routes.py # API routes definition
│ ├── auth_routes.py # Authentication routes (login, registration, etc.)
│ └── admin_routes.py # Protected admin routes
│ ├── __init__.py # Package initialization
│ ├── admin_routes.py # Admin routes for management interface
│ ├── agent_routes.py # Routes to manage agents
│ ├── auth_routes.py # Authentication routes (login, registration)
│ ├── chat_routes.py # Routes for chat interactions with agents
│ ├── client_routes.py # Routes to manage clients
│ ├── contact_routes.py # Routes to manage contacts
│ ├── mcp_server_routes.py # Routes to manage MCP servers
│ ├── session_routes.py # Routes to manage chat sessions
│ └── tool_routes.py # Routes to manage tools for agents
├── config/
│ ├── database.py # Database configuration
│ └── settings.py # General settings
@ -30,18 +37,21 @@ src/
│ └── models.py # SQLAlchemy models
├── schemas/
│ ├── schemas.py # Main Pydantic schemas
│ ├── chat.py # Chat schemas
│ ├── user.py # User and authentication schemas
│ └── audit.py # Audit logs schemas
├── services/
│ ├── agent_service.py # Business logic for agents
│ ├── agent_runner.py # Agent execution logic
│ ├── auth_service.py # JWT authentication logic
│ ├── audit_service.py # Audit logs logic
│ ├── client_service.py # Business logic for clients
│ ├── contact_service.py # Business logic for contacts
│ ├── mcp_server_service.py # Business logic for MCP servers
│ ├── tool_service.py # Business logic for tools
│ ├── user_service.py # User and authentication logic
│ ├── auth_service.py # JWT authentication logic
│ ├── email_service.py # Email sending service
│ └── audit_service.py # Audit logs logic
│ ├── mcp_server_service.py # Business logic for MCP servers
│ ├── session_service.py # Business logic for chat sessions
│ ├── tool_service.py # Business logic for tools
│ └── user_service.py # User management logic
├── templates/
│ ├── emails/
│ │ ├── base_email.html # Base template with common structure and styles
@ -49,7 +59,21 @@ src/
│ │ ├── password_reset.html # Password reset template
│ │ ├── welcome_email.html # Welcome email after verification
│ │ └── account_locked.html # Security alert for locked accounts
├── tests/
│ ├── __init__.py # Package initialization
│ ├── api/
│ │ ├── __init__.py # Package initialization
│ │ ├── test_auth_routes.py # Test for authentication routes
│ │ └── test_root.py # Test for root endpoint
│ ├── models/
│ │ ├── __init__.py # Package initialization
│ │ ├── test_models.py # Test for models
│ ├── services/
│ │ ├── __init__.py # Package initialization
│ │ ├── test_auth_service.py # Test for authentication service
│ │ └── test_user_service.py # Test for user service
└── utils/
├── logger.py # Logger configuration
└── security.py # Security utilities (JWT, hash)
```
@ -63,6 +87,15 @@ src/
- Code examples in documentation must be in English
- Commit messages must be in English
### Project Configuration
- Dependencies managed in `pyproject.toml` using modern Python packaging standards
- Development dependencies specified as optional dependencies in `pyproject.toml`
- Single source of truth for project metadata in `pyproject.toml`
- Build system configured to use setuptools
- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]`
- Code formatting with Black configured in `pyproject.toml`
- Linting with Flake8 configured in `.flake8`
### Schemas (Pydantic)
- Use `BaseModel` as base for all schemas
- Define fields with explicit types
@ -136,6 +169,28 @@ src/
- Indentation with 4 spaces
- Maximum of 79 characters per line
## Commit Rules
- Use Conventional Commits format for all commit messages
- Format: `<type>(<scope>): <description>`
- Types:
- `feat`: A new feature
- `fix`: A bug fix
- `docs`: Documentation changes
- `style`: Changes that do not affect code meaning (formatting, etc.)
- `refactor`: Code changes that neither fix a bug nor add a feature
- `perf`: Performance improvements
- `test`: Adding or modifying tests
- `chore`: Changes to build process or auxiliary tools
- Scope is optional and should be the module or component affected
- Description should be concise, in the imperative mood, and not capitalized
- Use body for more detailed explanations if needed
- Reference issues in the footer with `Fixes #123` or `Relates to #123`
- Examples:
- `feat(auth): add password reset functionality`
- `fix(api): correct validation error in client registration`
- `docs: update API documentation for new endpoints`
- `refactor(services): improve error handling in authentication`
## Best Practices
- Always validate input data
- Implement appropriate logging
@ -163,6 +218,7 @@ src/
## Useful Commands
- `make run`: Start the server
- `make run-prod`: Start the server in production mode
- `make alembic-revision message="description"`: Create new migration
- `make alembic-upgrade`: Apply pending migrations
- `make alembic-downgrade`: Revert last migration
@ -170,3 +226,8 @@ src/
- `make alembic-reset`: Reset database to initial state
- `make alembic-upgrade-cascade`: Force upgrade removing dependencies
- `make clear-cache`: Clean project cache
- `make seed-all`: Run all database seeders
- `make lint`: Run linting checks with flake8
- `make format`: Format code with black
- `make install`: Install project for development
- `make install-dev`: Install project with development dependencies

View File

@ -1,3 +1,47 @@
# Environment and IDE
.venv
venv
.env
.idea
.vscode
__pycache__
*.pyc
*.pyo
*.pyd
.Python
.pytest_cache
.coverage
htmlcov/
.tox/
# Version control
.git
.github
.gitignore
# Logs and temp files
logs
*.log
tmp
.DS_Store
# Docker
.dockerignore
Dockerfile*
docker-compose*
# Documentation
README.md
LICENSE
docs/
# Development tools
tests/
.flake8
pyproject.toml
requirements-dev.txt
Makefile
# Ambiente virtual
venv/
__pycache__/

8
.flake8 Normal file
View File

@ -0,0 +1,8 @@
[flake8]
max-line-length = 88
exclude = .git,__pycache__,venv,alembic/versions/*
ignore = E203, W503, E501
per-file-ignores =
__init__.py: F401
src/models/models.py: E712
alembic/*: E711,E712,F401

77
.gitignore vendored
View File

@ -1,13 +1,10 @@
# Byte-compiled / optimized / DLL files
# Python
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
@ -19,19 +16,10 @@ lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
@ -42,14 +30,57 @@ nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# IDE
.idea/
.vscode/
*.sublime-project
*.sublime-workspace
.DS_Store
# Logs
logs/
*.log
# Database
*.db
*.sqlite
*.sqlite3
backup/
# Local
local_settings.py
local.py
# Docker
.docker/
# Alembic versions
# alembic/versions/
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
@ -77,17 +108,6 @@ celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
.venv/
.env/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
@ -102,11 +122,8 @@ venv.bak/
.mypy_cache/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db

View File

@ -1,4 +1,4 @@
FROM python:3.11-slim
FROM python:3.10-slim
# Define o diretório de trabalho
WORKDIR /app
@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Copia os arquivos de requisitos
COPY requirements.txt .
# Instala as dependências
RUN pip install --no-cache-dir -r requirements.txt
# Copia o código-fonte
# Copy project files
COPY . .
# Install dependencies
RUN pip install --no-cache-dir -e .
# Configuração para produção
ENV PORT=8000 \
HOST=0.0.0.0 \
DEBUG=false
# Expose port
EXPOSE 8000
# Define o comando de inicialização
CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT

View File

@ -1,42 +1,42 @@
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv
# Comandos do Alembic
# Alembic commands
init:
alembic init alembics
# make alembic-revision message="descrição da migração"
# make alembic-revision message="migration description"
alembic-revision:
alembic revision --autogenerate -m "$(message)"
# Comando para atualizar o banco de dados
# Command to update database to latest version
alembic-upgrade:
alembic upgrade head
# Comando para voltar uma versão
# Command to downgrade one version
alembic-downgrade:
alembic downgrade -1
# Comando para rodar o servidor
# Command to run the server
run:
uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src
# Comando para limpar o cache em todas as pastas do projeto pastas pycache
# Command to run the server in production mode
run-prod:
uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4
# Command to clean cache in all project folders
clear-cache:
rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +
# Comando para criar uma nova migração
# Command to create a new migration and apply it
alembic-migrate:
alembic revision --autogenerate -m "$(message)" && alembic upgrade head
# Comando para resetar o banco de dados
# Command to reset the database
alembic-reset:
alembic downgrade base && alembic upgrade head
# Comando para forçar upgrade com CASCADE
alembic-upgrade-cascade:
psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head
# Comandos para executar seeders
# Commands to run seeders
seed-admin:
python -m scripts.seeders.admin_seeder
@ -58,7 +58,7 @@ seed-contacts:
seed-all:
python -m scripts.run_seeders
# Comandos Docker
# Docker commands
docker-build:
docker-compose build
@ -73,3 +73,20 @@ docker-logs:
docker-seed:
docker-compose exec api python -m scripts.run_seeders
# Testing, linting and formatting commands
lint:
flake8 src/ tests/
format:
black src/ tests/
# Virtual environment and installation commands
venv:
python -m venv venv
install:
pip install -e .
install-dev:
pip install -e ".[dev]"

53
conftest.py Normal file
View File

@ -0,0 +1,53 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from src.config.database import Base, get_db
from src.main import app
# Use in-memory SQLite for tests
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="function")
def db_session():
"""Creates a fresh database session for each test."""
Base.metadata.create_all(bind=engine) # Create tables
connection = engine.connect()
transaction = connection.begin()
session = TestingSessionLocal(bind=connection)
# Use our test database instead of the standard one
def override_get_db():
try:
yield session
session.commit()
finally:
session.close()
app.dependency_overrides[get_db] = override_get_db
yield session # The test will run here
# Teardown
transaction.rollback()
connection.close()
Base.metadata.drop_all(bind=engine)
app.dependency_overrides.clear()
@pytest.fixture(scope="function")
def client(db_session):
"""Creates a FastAPI TestClient with database session fixture."""
with TestClient(app) as test_client:
yield test_client

89
pyproject.toml Normal file
View File

@ -0,0 +1,89 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "evo-ai"
version = "1.0.0"
description = "API for executing AI agents"
readme = "README.md"
authors = [
{name = "EvoAI Team", email = "admin@evoai.com"}
]
requires-python = ">=3.10"
license = {text = "Proprietary"}
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"License :: Other/Proprietary License",
"Operating System :: OS Independent",
]
# Main dependencies
dependencies = [
"fastapi==0.115.12",
"uvicorn==0.34.2",
"pydantic==2.11.3",
"sqlalchemy==2.0.40",
"psycopg2-binary==2.9.10",
"google-cloud-aiplatform==1.90.0",
"python-dotenv==1.1.0",
"google-adk==0.3.0",
"litellm==1.67.4.post1",
"python-multipart==0.0.20",
"alembic==1.15.2",
"asyncpg==0.30.0",
"python-jose==3.4.0",
"passlib==1.7.4",
"sendgrid==6.11.0",
"pydantic-settings==2.9.1",
"fastapi_utils==0.8.0",
"bcrypt==4.3.0",
"jinja2==3.1.6",
]
[project.optional-dependencies]
dev = [
"black==25.1.0",
"flake8==7.2.0",
"pytest==8.3.5",
"pytest-cov==6.1.1",
"httpx==0.28.1",
"pytest-asyncio==0.26.0",
]
[tool.setuptools]
packages = ["src"]
[tool.black]
line-length = 88
target-version = ["py310"]
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
| venv
| alembic/versions
)/
'''
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_functions = "test_*"
python_classes = "Test*"
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["src"]
omit = ["tests/*", "alembic/*"]

View File

@ -1,22 +0,0 @@
fastapi
uvicorn
pydantic
sqlalchemy
psycopg2
psycopg2-binary
google-cloud-aiplatform
python-dotenv
google-adk
litellm
python-multipart
alembic
asyncpg
# Novas dependências para autenticação
python-jose[cryptography]
passlib[bcrypt]
sendgrid
pydantic[email]
pydantic-settings
fastapi_utils
bcrypt
jinja2

6
setup.py Normal file
View File

@ -0,0 +1,6 @@
"""Setup script for the package."""
from setuptools import setup
if __name__ == "__main__":
setup()

View File

@ -1,3 +1,3 @@
"""
Pacote principal da aplicação
Main package of the application
"""

View File

@ -1,14 +1,17 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
import uuid
from src.config.database import get_db
from src.core.jwt_middleware import get_jwt_token, verify_admin
from src.schemas.audit import AuditLogResponse, AuditLogFilter
from src.services.audit_service import get_audit_logs, create_audit_log
from src.services.user_service import get_admin_users, create_admin_user, deactivate_user
from src.services.user_service import (
get_admin_users,
create_admin_user,
deactivate_user,
)
from src.schemas.user import UserResponse, AdminUserCreate
router = APIRouter(
@ -18,6 +21,7 @@ router = APIRouter(
responses={403: {"description": "Permission denied"}},
)
# Audit routes
@router.get("/audit-logs", response_model=List[AuditLogResponse])
async def read_audit_logs(
@ -45,9 +49,10 @@ async def read_audit_logs(
resource_type=filters.resource_type,
resource_id=filters.resource_id,
start_date=filters.start_date,
end_date=filters.end_date
end_date=filters.end_date,
)
# Admin routes
@router.get("/users", response_model=List[UserResponse])
async def read_admin_users(
@ -70,6 +75,7 @@ async def read_admin_users(
"""
return get_admin_users(db, skip, limit)
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_new_admin_user(
user_data: AdminUserCreate,
@ -97,16 +103,13 @@ async def create_new_admin_user(
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unable to identify the logged in user"
detail="Unable to identify the logged in user",
)
# Create admin user
user, message = create_admin_user(db, user_data)
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=message
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
# Register action in audit log
create_audit_log(
@ -116,11 +119,12 @@ async def create_new_admin_user(
resource_type="admin_user",
resource_id=str(user.id),
details={"email": user.email},
request=request
request=request,
)
return user
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_admin_user(
user_id: uuid.UUID,
@ -145,23 +149,20 @@ async def deactivate_admin_user(
if not current_user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unable to identify the logged in user"
detail="Unable to identify the logged in user",
)
# Do not allow deactivating yourself
if str(user_id) == current_user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Unable to deactivate your own user"
detail="Unable to deactivate your own user",
)
# Deactivate user
success, message = deactivate_user(db, user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=message
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
# Register action in audit log
create_audit_log(
@ -171,5 +172,5 @@ async def deactivate_admin_user(
resource_type="admin_user",
resource_id=str(user_id),
details=None,
request=request
request=request,
)

View File

@ -7,10 +7,6 @@ from src.core.jwt_middleware import (
get_jwt_token,
verify_user_client,
)
from src.core.jwt_middleware import (
get_jwt_token,
verify_user_client,
)
from src.schemas.schemas import (
Agent,
AgentCreate,

View File

@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get
"""
user = authenticate_user(db, form_data.email, form_data.password)
if not user:
logger.warning(
f"Login attempt with invalid credentials: {form_data.email}"
)
logger.warning(f"Login attempt with invalid credentials: {form_data.email}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",

View File

@ -11,7 +11,11 @@ from src.services import (
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
from src.services.agent_runner import run_agent
from src.core.exceptions import AgentNotFoundError
from src.main import session_service, artifacts_service, memory_service
from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from datetime import datetime
import logging

View File

@ -19,7 +19,7 @@ from src.services.session_service import (
get_sessions_by_agent,
get_sessions_by_client,
)
from src.main import session_service
from src.services.service_providers import session_service
import logging
logger = logging.getLogger(__name__)
@ -30,6 +30,7 @@ router = APIRouter(
responses={404: {"description": "Not found"}},
)
# Session Routes
@router.get("/client/{client_id}", response_model=List[Adk_Session])
async def get_client_sessions(

View File

@ -10,6 +10,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
db = SessionLocal()
try:

View File

@ -1,55 +1,66 @@
from fastapi import HTTPException
from typing import Optional, Dict, Any
class BaseAPIException(HTTPException):
"""Base class for API exceptions"""
def __init__(
self,
status_code: int,
message: str,
error_code: str,
details: Optional[Dict[str, Any]] = None
details: Optional[Dict[str, Any]] = None,
):
super().__init__(status_code=status_code, detail={
super().__init__(
status_code=status_code,
detail={
"error": message,
"error_code": error_code,
"details": details or {}
})
"details": details or {},
},
)
class AgentNotFoundError(BaseAPIException):
"""Exception when the agent is not found"""
def __init__(self, agent_id: str):
super().__init__(
status_code=404,
message=f"Agent with ID {agent_id} not found",
error_code="AGENT_NOT_FOUND"
error_code="AGENT_NOT_FOUND",
)
class InvalidParameterError(BaseAPIException):
"""Exception for invalid parameters"""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=400,
message=message,
error_code="INVALID_PARAMETER",
details=details
details=details,
)
class InvalidRequestError(BaseAPIException):
"""Exception for invalid requests"""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=400,
message=message,
error_code="INVALID_REQUEST",
details=details
details=details,
)
class InternalServerError(BaseAPIException):
"""Exception for server errors"""
def __init__(self, message: str = "Server error"):
super().__init__(
status_code=500,
message=message,
error_code="INTERNAL_SERVER_ERROR"
status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR"
)

View File

@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
"""
Extracts and validates the JWT token
@ -34,9 +35,7 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
email: str = payload.get("sub")
@ -55,10 +54,11 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
logger.error(f"Error decoding JWT token: {str(e)}")
raise credentials_exception
async def verify_user_client(
payload: dict = Depends(get_jwt_token),
db: Session = Depends(get_db),
required_client_id: UUID = None
required_client_id: UUID = None,
) -> bool:
"""
Verifies if the user is associated with the specified client
@ -81,10 +81,12 @@ async def verify_user_client(
# Para não-admins, verificar se o client_id corresponde
user_client_id = payload.get("client_id")
if not user_client_id:
logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}")
logger.warning(
f"Non-admin user without client_id in token: {payload.get('sub')}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User not associated with a client"
detail="User not associated with a client",
)
# If no client_id is specified to verify, any client is valid
@ -93,14 +95,17 @@ async def verify_user_client(
# Verify if the user's client_id corresponds to the required_client_id
if str(user_client_id) != str(required_client_id):
logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}")
logger.warning(
f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to access resources of this client"
detail="Access denied to access resources of this client",
)
return True
async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
"""
Verifies if the user is an administrator
@ -118,12 +123,15 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
logger.warning(f"Access denied to admin: User {payload.get('sub')}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. Restricted to administrators."
detail="Access denied. Restricted to administrators.",
)
return True
def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]:
def get_current_user_client_id(
payload: dict = Depends(get_jwt_token),
) -> Optional[UUID]:
"""
Gets the ID of the client associated with the current user

View File

@ -1,35 +1,30 @@
import os
import sys
from pathlib import Path
# Add the root directory to PYTHONPATH
root_dir = Path(__file__).parent.parent
sys.path.append(str(root_dir))
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.config.database import engine, Base
from src.config.settings import settings
from src.utils.logger import setup_logger
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
# Necessary for other modules
from src.services.service_providers import session_service # noqa: F401
from src.services.service_providers import artifacts_service # noqa: F401
from src.services.service_providers import memory_service # noqa: F401
# Import routers after service initialization to avoid circular imports
from src.api.auth_routes import router as auth_router
from src.api.admin_routes import router as admin_router
from src.api.chat_routes import router as chat_router
from src.api.session_routes import router as session_router
from src.api.agent_routes import router as agent_router
from src.api.contact_routes import router as contact_router
from src.api.mcp_server_routes import router as mcp_server_router
from src.api.tool_routes import router as tool_router
from src.api.client_routes import router as client_router
import src.api.auth_routes
import src.api.admin_routes
import src.api.chat_routes
import src.api.session_routes
import src.api.agent_routes
import src.api.contact_routes
import src.api.mcp_server_routes
import src.api.tool_routes
import src.api.client_routes
# Add the root directory to PYTHONPATH
root_dir = Path(__file__).parent.parent
sys.path.append(str(root_dir))
# Configure logger
logger = setup_logger(__name__)
@ -52,8 +47,7 @@ app.add_middleware(
# PostgreSQL configuration
POSTGRES_CONNECTION_STRING = os.getenv(
"POSTGRES_CONNECTION_STRING",
"postgresql://postgres:root@localhost:5432/evo_ai"
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
)
# Create database tables
@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine)
API_PREFIX = "/api/v1"
# Define router references
auth_router = src.api.auth_routes.router
admin_router = src.api.admin_routes.router
chat_router = src.api.chat_routes.router
session_router = src.api.session_routes.router
agent_router = src.api.agent_routes.router
contact_router = src.api.contact_routes.router
mcp_server_router = src.api.mcp_server_routes.router
tool_router = src.api.tool_routes.router
client_router = src.api.client_routes.router
# Include routes
app.include_router(auth_router, prefix=API_PREFIX)
app.include_router(admin_router, prefix=API_PREFIX)
@ -79,5 +84,5 @@ def read_root():
"message": "Welcome to Evo AI API",
"documentation": "/docs",
"version": settings.API_VERSION,
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'"
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'",
}

View File

@ -1,9 +1,20 @@
from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean
from sqlalchemy import (
Column,
String,
UUID,
DateTime,
ForeignKey,
JSON,
Text,
CheckConstraint,
Boolean,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, backref
from src.config.database import Base
import uuid
class Client(Base):
__tablename__ = "clients"
@ -13,13 +24,16 @@ class Client(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class User(Base):
__tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
email = Column(String, unique=True, index=True, nullable=False)
password_hash = Column(String, nullable=False)
client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True)
client_id = Column(
UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True
)
is_active = Column(Boolean, default=False)
is_admin = Column(Boolean, default=False)
email_verified = Column(Boolean, default=False)
@ -31,7 +45,10 @@ class User(Base):
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
# Relationship with Client (One-to-One, optional for administrators)
client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan"))
client = relationship(
"Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")
)
class Contact(Base):
__tablename__ = "contacts"
@ -44,6 +61,7 @@ class Contact(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class Agent(Base):
__tablename__ = "agents"
@ -60,21 +78,30 @@ class Agent(Base):
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
__table_args__ = (
CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'),
CheckConstraint(
"type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
),
)
def to_dict(self):
"""Converts the object to a dictionary, converting UUIDs to strings"""
result = {}
for key, value in self.__dict__.items():
if key.startswith('_'):
if key.startswith("_"):
continue
if isinstance(value, uuid.UUID):
result[key] = str(value)
elif isinstance(value, dict):
result[key] = self._convert_dict(value)
elif isinstance(value, list):
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
result[key] = [
(
self._convert_dict(item)
if isinstance(item, dict)
else str(item) if isinstance(item, uuid.UUID) else item
)
for item in value
]
else:
result[key] = value
return result
@ -88,11 +115,19 @@ class Agent(Base):
elif isinstance(value, dict):
result[key] = self._convert_dict(value)
elif isinstance(value, list):
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
result[key] = [
(
self._convert_dict(item)
if isinstance(item, dict)
else str(item) if isinstance(item, uuid.UUID) else item
)
for item in value
]
else:
result[key] = value
return result
class MCPServer(Base):
__tablename__ = "mcp_servers"
@ -107,9 +142,12 @@ class MCPServer(Base):
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
__table_args__ = (
CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'),
CheckConstraint(
"type IN ('official', 'community')", name="check_mcp_server_type"
),
)
class Tool(Base):
__tablename__ = "tools"
@ -121,10 +159,11 @@ class Tool(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class Session(Base):
__tablename__ = "sessions"
# The directive below makes Alembic ignore this table in migrations
__table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}}
__table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}}
id = Column(String, primary_key=True)
app_name = Column(String)
@ -133,11 +172,14 @@ class Session(Base):
create_time = Column(DateTime(timezone=True))
update_time = Column(DateTime(timezone=True))
class AuditLog(Base):
__tablename__ = "audit_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
user_id = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
action = Column(String, nullable=False)
resource_type = Column(String, nullable=False)
resource_id = Column(String, nullable=True)

View File

@ -1,26 +1,38 @@
from typing import List, Optional, Dict, Any, Union
from typing import List, Optional, Dict, Union
from pydantic import BaseModel, Field
from uuid import UUID
class ToolConfig(BaseModel):
"""Configuration of a tool"""
id: UUID
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool")
envs: Dict[str, str] = Field(
default_factory=dict, description="Environment variables of the tool"
)
class Config:
from_attributes = True
class MCPServerConfig(BaseModel):
"""Configuration of an MCP server"""
id: UUID
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server")
tools: List[str] = Field(default_factory=list, description="List of tools of the server")
envs: Dict[str, str] = Field(
default_factory=dict, description="Environment variables of the server"
)
tools: List[str] = Field(
default_factory=list, description="List of tools of the server"
)
class Config:
from_attributes = True
class HTTPToolParameter(BaseModel):
"""Parameter of an HTTP tool"""
type: str
required: bool
description: str
@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel):
class Config:
from_attributes = True
class HTTPToolParameters(BaseModel):
"""Parameters of an HTTP tool"""
path_params: Optional[Dict[str, str]] = None
query_params: Optional[Dict[str, Union[str, List[str]]]] = None
body_params: Optional[Dict[str, HTTPToolParameter]] = None
@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel):
class Config:
from_attributes = True
class HTTPToolErrorHandling(BaseModel):
"""Configuration of error handling"""
timeout: int
retry_count: int
fallback_response: Dict[str, str]
@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel):
class Config:
from_attributes = True
class HTTPTool(BaseModel):
"""Configuration of an HTTP tool"""
name: str
method: str
values: Dict[str, str]
@ -60,42 +78,72 @@ class HTTPTool(BaseModel):
class Config:
from_attributes = True
class CustomTools(BaseModel):
"""Configuration of custom tools"""
http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools")
http_tools: List[HTTPTool] = Field(
default_factory=list, description="List of HTTP tools"
)
class Config:
from_attributes = True
class LLMConfig(BaseModel):
"""Configuration for LLM agents"""
tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools")
custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools")
mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers")
sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents")
tools: Optional[List[ToolConfig]] = Field(
default=None, description="List of available tools"
)
custom_tools: Optional[CustomTools] = Field(
default=None, description="Custom tools"
)
mcp_servers: Optional[List[MCPServerConfig]] = Field(
default=None, description="List of MCP servers"
)
sub_agents: Optional[List[UUID]] = Field(
default=None, description="List of IDs of sub-agents"
)
class Config:
from_attributes = True
class SequentialConfig(BaseModel):
"""Configuration for sequential agents"""
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order")
sub_agents: List[UUID] = Field(
..., description="List of IDs of sub-agents in execution order"
)
class Config:
from_attributes = True
class ParallelConfig(BaseModel):
"""Configuration for parallel agents"""
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution")
sub_agents: List[UUID] = Field(
..., description="List of IDs of sub-agents for parallel execution"
)
class Config:
from_attributes = True
class LoopConfig(BaseModel):
"""Configuration for loop agents"""
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution")
max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations")
condition: Optional[str] = Field(default=None, description="Condition to stop the loop")
sub_agents: List[UUID] = Field(
..., description="List of IDs of sub-agents for loop execution"
)
max_iterations: Optional[int] = Field(
default=None, description="Maximum number of iterations"
)
condition: Optional[str] = Field(
default=None, description="Condition to stop the loop"
)
class Config:
from_attributes = True

View File

@ -3,19 +3,25 @@ from typing import Optional, Dict, Any
from datetime import datetime
from uuid import UUID
class AuditLogBase(BaseModel):
"""Base schema for audit log"""
action: str
resource_type: str
resource_id: Optional[str] = None
details: Optional[Dict[str, Any]] = None
class AuditLogCreate(AuditLogBase):
"""Schema for creating audit log"""
pass
class AuditLogResponse(AuditLogBase):
"""Schema for audit log response"""
id: UUID
user_id: Optional[UUID] = None
ip_address: Optional[str] = None
@ -25,8 +31,10 @@ class AuditLogResponse(AuditLogBase):
class Config:
from_attributes = True
class AuditLogFilter(BaseModel):
"""Schema for audit log search filters"""
user_id: Optional[UUID] = None
action: Optional[str] = None
resource_type: Optional[str] = None

View File

@ -1,21 +1,33 @@
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
class ChatRequest(BaseModel):
"""Schema for chat requests"""
agent_id: str = Field(..., description="ID of the agent that will process the message")
contact_id: str = Field(..., description="ID of the contact that will process the message")
agent_id: str = Field(
..., description="ID of the agent that will process the message"
)
contact_id: str = Field(
..., description="ID of the contact that will process the message"
)
message: str = Field(..., description="User message")
class ChatResponse(BaseModel):
"""Schema for chat responses"""
response: str = Field(..., description="Agent response")
status: str = Field(..., description="Operation status")
error: Optional[str] = Field(None, description="Error message, if there is one")
timestamp: str = Field(..., description="Timestamp of the response")
class ErrorResponse(BaseModel):
"""Schema for error responses"""
error: str = Field(..., description="Error message")
status_code: int = Field(..., description="HTTP status code of the error")
details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
details: Optional[Dict[str, Any]] = Field(
None, description="Additional error details"
)

View File

@ -4,15 +4,18 @@ from datetime import datetime
from uuid import UUID
import uuid
import re
from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig
from src.schemas.agent_config import LLMConfig
class ClientBase(BaseModel):
name: str
email: Optional[EmailStr] = None
class ClientCreate(ClientBase):
pass
class Client(ClientBase):
id: UUID
created_at: datetime
@ -20,14 +23,17 @@ class Client(ClientBase):
class Config:
from_attributes = True
class ContactBase(BaseModel):
ext_id: Optional[str] = None
name: Optional[str] = None
meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
class ContactCreate(ContactBase):
client_id: UUID
class Contact(ContactBase):
id: UUID
client_id: UUID
@ -35,67 +41,80 @@ class Contact(ContactBase):
class Config:
from_attributes = True
class AgentBase(BaseModel):
name: str = Field(..., description="Agent name (no spaces or special characters)")
description: Optional[str] = Field(None, description="Agent description")
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
model: Optional[str] = Field(None, description="Agent model (required only for llm type)")
api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)")
model: Optional[str] = Field(
None, description="Agent model (required only for llm type)"
)
api_key: Optional[str] = Field(
None, description="Agent API Key (required only for llm type)"
)
instruction: Optional[str] = None
config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type")
config: Union[LLMConfig, Dict[str, Any]] = Field(
..., description="Agent configuration based on type"
)
@validator('name')
@validator("name")
def validate_name(cls, v):
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
raise ValueError('Agent name cannot contain spaces or special characters')
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("Agent name cannot contain spaces or special characters")
return v
@validator('type')
@validator("type")
def validate_type(cls, v):
if v not in ['llm', 'sequential', 'parallel', 'loop']:
raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop')
if v not in ["llm", "sequential", "parallel", "loop"]:
raise ValueError(
"Invalid agent type. Must be: llm, sequential, parallel or loop"
)
return v
@validator('model')
@validator("model")
def validate_model(cls, v, values):
if 'type' in values and values['type'] == 'llm' and not v:
raise ValueError('Model is required for llm type agents')
if "type" in values and values["type"] == "llm" and not v:
raise ValueError("Model is required for llm type agents")
return v
@validator('api_key')
@validator("api_key")
def validate_api_key(cls, v, values):
if 'type' in values and values['type'] == 'llm' and not v:
raise ValueError('API Key is required for llm type agents')
if "type" in values and values["type"] == "llm" and not v:
raise ValueError("API Key is required for llm type agents")
return v
@validator('config')
@validator("config")
def validate_config(cls, v, values):
if 'type' not in values:
if "type" not in values:
return v
if values['type'] == 'llm':
if values["type"] == "llm":
if isinstance(v, dict):
try:
# Convert the dictionary to LLMConfig
v = LLMConfig(**v)
except Exception as e:
raise ValueError(f'Invalid LLM configuration for agent: {str(e)}')
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
elif not isinstance(v, LLMConfig):
raise ValueError('Invalid LLM configuration for agent')
elif values['type'] in ['sequential', 'parallel', 'loop']:
raise ValueError("Invalid LLM configuration for agent")
elif values["type"] in ["sequential", "parallel", "loop"]:
if not isinstance(v, dict):
raise ValueError(f'Invalid configuration for agent {values["type"]}')
if 'sub_agents' not in v:
if "sub_agents" not in v:
raise ValueError(f'Agent {values["type"]} must have sub_agents')
if not isinstance(v['sub_agents'], list):
raise ValueError('sub_agents must be a list')
if not v['sub_agents']:
raise ValueError(f'Agent {values["type"]} must have at least one sub-agent')
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
if not v["sub_agents"]:
raise ValueError(
f'Agent {values["type"]} must have at least one sub-agent'
)
return v
class AgentCreate(AgentBase):
client_id: UUID
class Agent(AgentBase):
id: UUID
client_id: UUID
@ -105,6 +124,7 @@ class Agent(AgentBase):
class Config:
from_attributes = True
class MCPServerBase(BaseModel):
name: str
description: Optional[str] = None
@ -113,9 +133,11 @@ class MCPServerBase(BaseModel):
tools: List[str] = Field(default_factory=list)
type: str = Field(default="official")
class MCPServerCreate(MCPServerBase):
pass
class MCPServer(MCPServerBase):
id: uuid.UUID
created_at: datetime
@ -124,15 +146,18 @@ class MCPServer(MCPServerBase):
class Config:
from_attributes = True
class ToolBase(BaseModel):
name: str
description: Optional[str] = None
config_json: Dict[str, Any] = Field(default_factory=dict)
environments: Dict[str, Any] = Field(default_factory=dict)
class ToolCreate(ToolBase):
pass
class Tool(ToolBase):
id: uuid.UUID
created_at: datetime

View File

@ -1,23 +1,28 @@
from pydantic import BaseModel, EmailStr, Field
from pydantic import BaseModel, EmailStr
from typing import Optional
from datetime import datetime
from uuid import UUID
class UserBase(BaseModel):
email: EmailStr
class UserCreate(UserBase):
password: str
name: str # For client creation
class AdminUserCreate(UserBase):
password: str
name: str
class UserLogin(BaseModel):
email: EmailStr
password: str
class UserResponse(UserBase):
id: UUID
client_id: Optional[UUID] = None
@ -29,22 +34,27 @@ class UserResponse(UserBase):
class Config:
from_attributes = True
class TokenResponse(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
sub: str # user email
exp: datetime
is_admin: bool
client_id: Optional[UUID] = None
class PasswordReset(BaseModel):
token: str
new_password: str
class ForgotPassword(BaseModel):
email: EmailStr
class MessageResponse(BaseModel):
message: str

View File

@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest
from google.adk.tools import load_memory
from typing import Optional
import logging
import os
import requests
import os
from datetime import datetime
logger = setup_logger(__name__)
@ -83,7 +82,7 @@ def before_model_callback(
llm_request.config.system_instruction = modified_text
logger.debug(
f"📝 System instruction updated with search results and history"
"📝 System instruction updated with search results and history"
)
else:
logger.warning("⚠️ No results found in the search")
@ -180,7 +179,9 @@ class AgentBuilder:
mcp_tools = []
mcp_exit_stack = None
if agent.config.get("mcp_servers"):
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db)
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
agent.config, self.db
)
# Combine all tools
all_tools = custom_tools + mcp_tools
@ -201,10 +202,13 @@ class AgentBuilder:
# Check if load_memory is enabled
# before_model_callback_func = None
if agent.config.get("load_memory") == True:
if agent.config.get("load_memory"):
all_tools.append(load_memory)
# before_model_callback_func = before_model_callback
formatted_prompt = formatted_prompt + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
formatted_prompt = (
formatted_prompt
+ "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
)
return (
LlmAgent(

View File

@ -22,9 +22,7 @@ async def run_agent(
db: Session,
):
try:
logger.info(
f"Starting execution of agent {agent_id} for contact {contact_id}"
)
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
logger.info(f"Received message: {message}")
get_root_agent = get_agent(db, agent_id)

View File

@ -216,9 +216,7 @@ async def update_agent(
return agent
except Exception as e:
db.rollback()
raise HTTPException(
status_code=500, detail=f"Error updating agent: {str(e)}"
)
raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")
def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:

View File

@ -1,15 +1,15 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from src.models.models import AuditLog, User
from src.models.models import AuditLog
from datetime import datetime
from fastapi import Request
from typing import Optional, Dict, Any, List
import uuid
import logging
import json
logger = logging.getLogger(__name__)
def create_audit_log(
db: Session,
user_id: Optional[uuid.UUID],
@ -17,7 +17,7 @@ def create_audit_log(
resource_type: str,
resource_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
request: Optional[Request] = None
request: Optional[Request] = None,
) -> Optional[AuditLog]:
"""
Create a new audit log
@ -39,7 +39,7 @@ def create_audit_log(
user_agent = None
if request:
ip_address = request.client.host if hasattr(request, 'client') else None
ip_address = request.client.host if hasattr(request, "client") else None
user_agent = request.headers.get("user-agent")
# Convert details to serializable format
@ -56,7 +56,7 @@ def create_audit_log(
resource_id=str(resource_id) if resource_id else None,
details=details,
ip_address=ip_address,
user_agent=user_agent
user_agent=user_agent,
)
db.add(audit_log)
@ -64,8 +64,8 @@ def create_audit_log(
db.refresh(audit_log)
logger.info(
f"Audit log created: {action} in {resource_type}" +
(f" (ID: {resource_id})" if resource_id else "")
f"Audit log created: {action} in {resource_type}"
+ (f" (ID: {resource_id})" if resource_id else "")
)
return audit_log
@ -78,6 +78,7 @@ def create_audit_log(
logger.error(f"Unexpected error creating audit log: {str(e)}")
return None
def get_audit_logs(
db: Session,
skip: int = 0,
@ -87,7 +88,7 @@ def get_audit_logs(
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
end_date: Optional[datetime] = None,
) -> List[AuditLog]:
"""
Get audit logs with optional filters

View File

@ -16,7 +16,10 @@ logger = logging.getLogger(__name__)
# Define OAuth2 authentication scheme with password flow
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
async def get_current_user(
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> User:
"""
Get the current user from the JWT token
@ -39,9 +42,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
try:
# Decode the token
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
# Extract token data
@ -61,7 +62,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
sub=email,
exp=datetime.fromtimestamp(exp),
is_admin=payload.get("is_admin", False),
client_id=payload.get("client_id")
client_id=payload.get("client_id"),
)
except JWTError as e:
@ -77,13 +78,15 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
if not user.is_active:
logger.warning(f"Attempt to access inactive user: {user.email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Check if the current user is active
@ -99,12 +102,14 @@ async def get_current_active_user(current_user: User = Depends(get_current_user)
if not current_user.is_active:
logger.warning(f"Attempt to access inactive user: {current_user.email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return current_user
async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
async def get_current_admin_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Check if the current user is an administrator
@ -118,13 +123,16 @@ async def get_current_admin_user(current_user: User = Depends(get_current_user))
HTTPException: If the user is not an administrator
"""
if not current_user.is_admin:
logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}")
logger.warning(
f"Attempt to access admin by non-admin user: {current_user.email}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. Restricted to administrators."
detail="Access denied. Restricted to administrators.",
)
return current_user
def create_access_token(user: User) -> str:
"""
Create a JWT access token for the user

View File

@ -11,6 +11,7 @@ import logging
logger = logging.getLogger(__name__)
def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
"""Search for a client by ID"""
try:
@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
logger.error(f"Error searching for client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for client"
detail="Error searching for client",
)
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
"""Search for all clients with pagination"""
try:
@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
logger.error(f"Error searching for clients: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for clients"
detail="Error searching for clients",
)
def create_client(db: Session, client: ClientCreate) -> Client:
"""Create a new client"""
try:
@ -51,10 +54,13 @@ def create_client(db: Session, client: ClientCreate) -> Client:
logger.error(f"Error creating client: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating client"
detail="Error creating client",
)
def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]:
def update_client(
db: Session, client_id: uuid.UUID, client: ClientCreate
) -> Optional[Client]:
"""Updates an existing client"""
try:
db_client = get_client(db, client_id)
@ -73,9 +79,10 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op
logger.error(f"Error updating client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating client"
detail="Error updating client",
)
def delete_client(db: Session, client_id: uuid.UUID) -> bool:
"""Removes a client"""
try:
@ -92,10 +99,13 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:
logger.error(f"Error removing client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing client"
detail="Error removing client",
)
def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]:
def create_client_with_user(
db: Session, client_data: ClientCreate, user_data: UserCreate
) -> Tuple[Optional[Client], str]:
"""
Creates a new client with an associated user

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)
def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
"""Search for a contact by ID"""
try:
@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
logger.error(f"Error searching for contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contact"
detail="Error searching for contact",
)
def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]:
def get_contacts_by_client(
db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
) -> List[Contact]:
"""Search for contacts of a client with pagination"""
try:
return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all()
return (
db.query(Contact)
.filter(Contact.client_id == client_id)
.offset(skip)
.limit(limit)
.all()
)
except SQLAlchemyError as e:
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contacts"
detail="Error searching for contacts",
)
def create_contact(db: Session, contact: ContactCreate) -> Contact:
"""Create a new contact"""
try:
@ -49,10 +60,13 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:
logger.error(f"Error creating contact: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating contact"
detail="Error creating contact",
)
def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]:
def update_contact(
db: Session, contact_id: uuid.UUID, contact: ContactCreate
) -> Optional[Contact]:
"""Update an existing contact"""
try:
db_contact = get_contact(db, contact_id)
@ -71,9 +85,10 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -
logger.error(f"Error updating contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating contact"
detail="Error updating contact",
)
def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
"""Remove a contact"""
try:
@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
logger.error(f"Error removing contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing contact"
detail="Error removing contact",
)

View File

@ -6,6 +6,7 @@ from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class CustomToolBuilder:
def __init__(self):
self.tools = []
@ -53,7 +54,9 @@ class CustomToolBuilder:
# Adds default values to query params if they are not present
for param, value in values.items():
if param not in query_params and param not in parameters.get("path_params", {}):
if param not in query_params and param not in parameters.get(
"path_params", {}
):
query_params[param] = value
# Processa body parameters
@ -64,7 +67,11 @@ class CustomToolBuilder:
# Adds default values to body if they are not present
for param, value in values.items():
if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}):
if (
param not in body_data
and param not in query_params
and param not in parameters.get("path_params", {})
):
body_data[param] = value
# Makes the HTTP request
@ -74,7 +81,7 @@ class CustomToolBuilder:
headers=processed_headers,
params=query_params,
json=body_data if body_data else None,
timeout=error_handling.get("timeout", 30)
timeout=error_handling.get("timeout", 30),
)
if response.status_code >= 400:
@ -87,10 +94,12 @@ class CustomToolBuilder:
except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(error_handling.get("fallback_response", {
"error": "tool_execution_error",
"message": str(e)
}))
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)
# Adds dynamic docstring based on the configuration
param_docs = []
@ -109,7 +118,9 @@ class CustomToolBuilder:
# Adds body parameters
for param, param_config in parameters.get("body_params", {}).items():
required = "Required" if param_config.get("required", False) else "Optional"
param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}")
param_docs.append(
f"{param} ({param_config['type']}, {required}): {param_config['description']}"
)
# Adds default values
if values:

View File

@ -16,9 +16,10 @@ os.makedirs(templates_dir, exist_ok=True)
# Configure Jinja2 with the templates directory
env = Environment(
loader=FileSystemLoader(templates_dir),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)
def _render_template(template_name: str, context: dict) -> str:
"""
Render a template with the provided data
@ -37,6 +38,7 @@ def _render_template(template_name: str, context: dict) -> str:
logger.error(f"Error rendering template '{template_name}': {str(e)}")
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"
def send_verification_email(email: str, token: str) -> bool:
"""
Send a verification email to the user
@ -56,11 +58,16 @@ def send_verification_email(email: str, token: str) -> bool:
verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"
html_content = _render_template('verification_email', {
'verification_link': verification_link,
'user_name': email.split('@')[0], # Use part of the email as temporary name
'current_year': datetime.now().year
})
html_content = _render_template(
"verification_email",
{
"verification_link": verification_link,
"user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
@ -71,13 +78,16 @@ def send_verification_email(email: str, token: str) -> bool:
logger.info(f"Verification email sent to {email}")
return True
else:
logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send verification email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending verification email to {email}: {str(e)}")
return False
def send_password_reset_email(email: str, token: str) -> bool:
"""
Send a password reset email to the user
@ -97,11 +107,16 @@ def send_password_reset_email(email: str, token: str) -> bool:
reset_link = f"{settings.APP_URL}/reset-password?token={token}"
html_content = _render_template('password_reset', {
'reset_link': reset_link,
'user_name': email.split('@')[0], # Use part of the email as temporary name
'current_year': datetime.now().year
})
html_content = _render_template(
"password_reset",
{
"reset_link": reset_link,
"user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
@ -112,13 +127,16 @@ def send_password_reset_email(email: str, token: str) -> bool:
logger.info(f"Password reset email sent to {email}")
return True
else:
logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send password reset email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending password reset email to {email}: {str(e)}")
return False
def send_welcome_email(email: str, user_name: str = None) -> bool:
"""
Send a welcome email to the user after verification
@ -138,11 +156,14 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
dashboard_link = f"{settings.APP_URL}/dashboard"
html_content = _render_template('welcome_email', {
'dashboard_link': dashboard_link,
'user_name': user_name or email.split('@')[0],
'current_year': datetime.now().year
})
html_content = _render_template(
"welcome_email",
{
"dashboard_link": dashboard_link,
"user_name": user_name or email.split("@")[0],
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
@ -153,14 +174,19 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
logger.info(f"Welcome email sent to {email}")
return True
else:
logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send welcome email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:
logger.error(f"Error sending welcome email to {email}: {str(e)}")
return False
def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool:
def send_account_locked_email(
email: str, reset_token: str, failed_attempts: int, time_period: str
) -> bool:
"""
Send an email informing that the account has been locked after login attempts
@ -181,13 +207,16 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"
html_content = _render_template('account_locked', {
'reset_link': reset_link,
'user_name': email.split('@')[0],
'failed_attempts': failed_attempts,
'time_period': time_period,
'current_year': datetime.now().year
})
html_content = _render_template(
"account_locked",
{
"reset_link": reset_link,
"user_name": email.split("@")[0],
"failed_attempts": failed_attempts,
"time_period": time_period,
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content)
@ -198,7 +227,9 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
logger.info(f"Account locked email sent to {email}")
return True
else:
logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send account locked email to {email}. Status: {response.status_code}"
)
return False
except Exception as e:

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)
def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
"""Search for an MCP server by ID"""
try:
@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP server"
detail="Error searching for MCP server",
)
def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
"""Search for all MCP servers with pagination"""
try:
@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer
logger.error(f"Error searching for MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP servers"
detail="Error searching for MCP servers",
)
def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
"""Create a new MCP server"""
try:
@ -49,10 +52,13 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
logger.error(f"Error creating MCP server: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating MCP server"
detail="Error creating MCP server",
)
def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]:
def update_mcp_server(
db: Session, server_id: uuid.UUID, server: MCPServerCreate
) -> Optional[MCPServer]:
"""Update an existing MCP server"""
try:
db_server = get_mcp_server(db, server_id)
@ -71,9 +77,10 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate
logger.error(f"Error updating MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating MCP server"
detail="Error updating MCP server",
)
def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
"""Remove an MCP server"""
try:
@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
logger.error(f"Error removing MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing MCP server"
detail="Error removing MCP server",
)

View File

@ -12,20 +12,22 @@ from sqlalchemy.orm import Session
logger = setup_logger(__name__)
class MCPService:
def __init__(self):
self.tools = []
self.exit_stack = AsyncExitStack()
async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]:
async def _connect_to_mcp_server(
self, server_config: Dict[str, Any]
) -> Tuple[List[Any], Optional[AsyncExitStack]]:
"""Connect to a specific MCP server and return its tools."""
try:
# Determines the type of server (local or remote)
if "url" in server_config:
# Remote server (SSE)
connection_params = SseServerParams(
url=server_config["url"],
headers=server_config.get("headers", {})
url=server_config["url"], headers=server_config.get("headers", {})
)
else:
# Local server (Stdio)
@ -39,9 +41,7 @@ class MCPService:
os.environ[key] = value
connection_params = StdioServerParameters(
command=command,
args=args,
env=env
command=command, args=args, env=env
)
tools, exit_stack = await MCPToolset.from_server(
@ -74,7 +74,9 @@ class MCPService:
return filtered_tools
def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]:
def _filter_tools_by_agent(
self, tools: List[Any], agent_tools: List[str]
) -> List[Any]:
"""Filters tools compatible with the agent."""
filtered_tools = []
for tool in tools:
@ -83,7 +85,9 @@ class MCPService:
filtered_tools.append(tool)
return filtered_tools
async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]:
async def build_tools(
self, mcp_config: Dict[str, Any], db: Session
) -> Tuple[List[Any], AsyncExitStack]:
"""Builds a list of tools from multiple MCP servers."""
self.tools = []
self.exit_stack = AsyncExitStack()
@ -92,7 +96,7 @@ class MCPService:
for server in mcp_config.get("mcp_servers", []):
try:
# Search for the MCP server in the database
mcp_server = get_mcp_server(db, server['id'])
mcp_server = get_mcp_server(db, server["id"])
if not mcp_server:
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
continue
@ -101,14 +105,16 @@ class MCPService:
server_config = mcp_server.config_json.copy()
# Replaces the environment variables in the config_json
if 'env' in server_config:
for key, value in server_config['env'].items():
if value.startswith('env@@'):
env_key = value.replace('env@@', '')
if env_key in server.get('envs', {}):
server_config['env'][key] = server['envs'][env_key]
if "env" in server_config:
for key, value in server_config["env"].items():
if value.startswith("env@@"):
env_key = value.replace("env@@", "")
if env_key in server.get("envs", {}):
server_config["env"][key] = server["envs"][env_key]
else:
logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}")
logger.warning(
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
)
continue
logger.info(f"Connecting to MCP server: {mcp_server.name}")
@ -119,20 +125,28 @@ class MCPService:
filtered_tools = self._filter_incompatible_tools(tools)
# Filters tools compatible with the agent
agent_tools = server.get('tools', [])
filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools)
agent_tools = server.get("tools", [])
filtered_tools = self._filter_tools_by_agent(
filtered_tools, agent_tools
)
self.tools.extend(filtered_tools)
# Registers the exit_stack with the AsyncExitStack
await self.exit_stack.enter_async_context(exit_stack)
logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.")
logger.info(
f"Connected successfully. Added {len(filtered_tools)} tools."
)
else:
logger.warning(f"Failed to connect or no tools available for {mcp_server.name}")
logger.warning(
f"Failed to connect or no tools available for {mcp_server.name}"
)
except Exception as e:
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
continue
logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.")
logger.info(
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
)
return self.tools, self.exit_stack

View File

@ -0,0 +1,9 @@
from src.config.settings import settings
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()

View File

@ -132,7 +132,7 @@ def get_session_events(
session = get_session_by_id(session_service, session_id)
# If we get here, the session exists (get_session_by_id already validates)
if not hasattr(session, 'events') or session.events is None:
if not hasattr(session, "events") or session.events is None:
return []
return session.events

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)
def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
"""Search for a tool by ID"""
try:
@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
logger.error(f"Error searching for tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tool"
detail="Error searching for tool",
)
def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
"""Search for all tools with pagination"""
try:
@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
logger.error(f"Error searching for tools: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tools"
detail="Error searching for tools",
)
def create_tool(db: Session, tool: ToolCreate) -> Tool:
"""Creates a new tool"""
try:
@ -49,9 +52,10 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:
logger.error(f"Error creating tool: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating tool"
detail="Error creating tool",
)
def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
"""Updates an existing tool"""
try:
@ -71,9 +75,10 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T
logger.error(f"Error updating tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating tool"
detail="Error updating tool",
)
def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
"""Remove a tool"""
try:
@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
logger.error(f"Error removing tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing tool"
detail="Error removing tool",
)

View File

@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError
from src.models.models import User, Client
from src.schemas.user import UserCreate
from src.utils.security import get_password_hash, verify_password, generate_token
from src.services.email_service import send_verification_email, send_password_reset_email
from src.services.email_service import (
send_verification_email,
send_password_reset_email,
)
from datetime import datetime, timedelta
import uuid
import logging
@ -11,7 +14,13 @@ from typing import Optional, Tuple
logger = logging.getLogger(__name__)
def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]:
def create_user(
db: Session,
user_data: UserCreate,
is_admin: bool = False,
client_id: Optional[uuid.UUID] = None,
) -> Tuple[Optional[User], str]:
"""
Creates a new user in the system
@ -28,7 +37,9 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
# Check if email already exists
db_user = db.query(User).filter(User.email == user_data.email).first()
if db_user:
logger.warning(f"Attempt to register with existing email: {user_data.email}")
logger.warning(
f"Attempt to register with existing email: {user_data.email}"
)
return None, "Email already registered"
# Create verification token
@ -56,7 +67,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
is_active=False, # Inactive until email is verified
email_verified=False,
verification_token=verification_token,
verification_token_expiry=token_expiry
verification_token_expiry=token_expiry,
)
db.add(user)
db.commit()
@ -68,7 +79,10 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
# We don't do rollback here, we just log the error
logger.info(f"User created successfully: {user.email}")
return user, "User created successfully. Check your email to activate your account."
return (
user,
"User created successfully. Check your email to activate your account.",
)
except SQLAlchemyError as e:
db.rollback()
@ -79,6 +93,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
logger.error(f"Unexpected error creating user: {str(e)}")
return None, f"Unexpected error: {str(e)}"
def verify_email(db: Session, token: str) -> Tuple[bool, str]:
"""
Verify the user's email using the provided token
@ -111,7 +126,9 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
expiry = expiry.replace(tzinfo=now.tzinfo)
if expiry < now:
logger.warning(f"Attempt to verify with expired token for user: {user.email}")
logger.warning(
f"Attempt to verify with expired token for user: {user.email}"
)
return False, "Verification token expired"
# Update user
@ -133,6 +150,7 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
logger.error(f"Unexpected error verifying email: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
"""
Resend the verification email
@ -149,11 +167,15 @@ def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
user = db.query(User).filter(User.email == email).first()
if not user:
logger.warning(f"Attempt to resend verification email for non-existent email: {email}")
logger.warning(
f"Attempt to resend verification email for non-existent email: {email}"
)
return False, "Email not found"
if user.email_verified:
logger.info(f"Attempt to resend verification email for already verified email: {email}")
logger.info(
f"Attempt to resend verification email for already verified email: {email}"
)
return False, "Email already verified"
# Generate new token
@ -184,6 +206,7 @@ def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
logger.error(f"Unexpected error resending verification: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
"""
Initiates the password recovery process
@ -202,7 +225,10 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
if not user:
# For security, we don't inform if the email exists or not
logger.info(f"Attempt to recover password for non-existent email: {email}")
return True, "If the email is registered, you will receive instructions to reset your password."
return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)
# Generate reset token
reset_token = generate_token()
@ -221,7 +247,10 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
return False, "Failed to send password reset email"
logger.info(f"Password reset email sent successfully to: {user.email}")
return True, "If the email is registered, you will receive instructions to reset your password."
return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)
except SQLAlchemyError as e:
db.rollback()
@ -232,6 +261,7 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
logger.error(f"Unexpected error processing password recovery: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
"""
Resets the user's password using the provided token
@ -254,7 +284,9 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st
# Check if the token has expired
if user.password_reset_expiry < datetime.utcnow():
logger.warning(f"Attempt to reset password with expired token for user: {user.email}")
logger.warning(
f"Attempt to reset password with expired token for user: {user.email}"
)
return False, "Password reset token expired"
# Update password
@ -264,7 +296,10 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st
db.commit()
logger.info(f"Password reset successfully for user: {user.email}")
return True, "Password reset successfully. You can now login with your new password."
return (
True,
"Password reset successfully. You can now login with your new password.",
)
except SQLAlchemyError as e:
db.rollback()
@ -275,6 +310,7 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st
logger.error(f"Unexpected error resetting password: {str(e)}")
return False, f"Unexpected error: {str(e)}"
def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""
Searches for a user by email
@ -292,6 +328,7 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:
logger.error(f"Error searching for user by email: {str(e)}")
return None
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticates a user with email and password
@ -313,6 +350,7 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
return None
return user
def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
"""
Lists the admin users
@ -326,7 +364,7 @@ def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
List[User]: List of admin users
"""
try:
users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all()
users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
logger.info(f"List of admins: {len(users)} found")
return users
@ -338,6 +376,7 @@ def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
logger.error(f"Unexpected error listing admins: {str(e)}")
return []
def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
"""
Creates a new admin user
@ -351,6 +390,7 @@ def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User
"""
return create_user(db, user_data, is_admin=True)
def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
"""
Deactivates a user (does not delete, only marks as inactive)

View File

@ -3,6 +3,7 @@ import os
import sys
from src.config.settings import settings
class CustomFormatter(logging.Formatter):
"""Custom formatter for logs"""
@ -12,14 +13,16 @@ class CustomFormatter(logging.Formatter):
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
format_template = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
format_template = (
"%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
)
FORMATS = {
logging.DEBUG: grey + format_template + reset,
logging.INFO: grey + format_template + reset,
logging.WARNING: yellow + format_template + reset,
logging.ERROR: red + format_template + reset,
logging.CRITICAL: bold_red + format_template + reset
logging.CRITICAL: bold_red + format_template + reset,
}
def format(self, record):
@ -27,6 +30,7 @@ class CustomFormatter(logging.Formatter):
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
def setup_logger(name: str) -> logging.Logger:
"""
Configures a custom logger

View File

@ -11,7 +11,8 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__)
# Fix bcrypt error with passlib
if not hasattr(bcrypt, '__about__'):
if not hasattr(bcrypt, "__about__"):
@dataclass
class BcryptAbout:
__version__: str = getattr(bcrypt, "__version__")
@ -21,31 +22,33 @@ if not hasattr(bcrypt, '__about__'):
# Context for password hashing using bcrypt
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password: str) -> str:
"""Creates a password hash"""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies if the provided password matches the stored hash"""
return pwd_context.verify(plain_password, hashed_password)
def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
"""Creates a JWT token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.JWT_EXPIRATION_TIME
)
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
return encoded_jwt
def generate_token(length: int = 32) -> str:
"""Generates a secure token for email verification or password reset"""
alphabet = string.ascii_letters + string.digits
token = ''.join(secrets.choice(alphabet) for _ in range(length))
token = "".join(secrets.choice(alphabet) for _ in range(length))
return token

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# Package initialization for tests

1
tests/api/__init__.py Normal file
View File

@ -0,0 +1 @@
# API tests package

11
tests/api/test_root.py Normal file
View File

@ -0,0 +1,11 @@
def test_read_root(client):
"""
Test that the root endpoint returns the correct response.
"""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert "message" in data
assert "documentation" in data
assert "version" in data
assert "auth" in data

View File

@ -0,0 +1 @@
# Services tests package

View File

@ -0,0 +1,27 @@
from src.services.auth_service import create_access_token
from src.models.models import User
import uuid
def test_create_access_token():
"""
Test that an access token is created with the correct data.
"""
# Create a mock user
user = User(
id=uuid.uuid4(),
email="test@example.com",
hashed_password="hashed_password",
is_active=True,
is_admin=False,
name="Test User",
client_id=uuid.uuid4(),
)
# Create token
token = create_access_token(user)
# Simple validation
assert token is not None
assert isinstance(token, str)
assert len(token) > 0