chore: update project structure and add testing framework

This commit is contained in:
Davidson Gomes 2025-04-28 20:41:10 -03:00
parent 7af234ef48
commit e7e030dfd5
49 changed files with 1261 additions and 619 deletions

View File

@ -17,9 +17,16 @@
``` ```
src/ src/
├── api/ ├── api/
│ ├── routes.py # API routes definition │ ├── __init__.py # Package initialization
│ ├── auth_routes.py # Authentication routes (login, registration, etc.) │ ├── admin_routes.py # Admin routes for management interface
│ └── admin_routes.py # Protected admin routes │ ├── agent_routes.py # Routes to manage agents
│ ├── auth_routes.py # Authentication routes (login, registration)
│ ├── chat_routes.py # Routes for chat interactions with agents
│ ├── client_routes.py # Routes to manage clients
│ ├── contact_routes.py # Routes to manage contacts
│ ├── mcp_server_routes.py # Routes to manage MCP servers
│ ├── session_routes.py # Routes to manage chat sessions
│ └── tool_routes.py # Routes to manage tools for agents
├── config/ ├── config/
│ ├── database.py # Database configuration │ ├── database.py # Database configuration
│ └── settings.py # General settings │ └── settings.py # General settings
@ -30,18 +37,21 @@ src/
│ └── models.py # SQLAlchemy models │ └── models.py # SQLAlchemy models
├── schemas/ ├── schemas/
│ ├── schemas.py # Main Pydantic schemas │ ├── schemas.py # Main Pydantic schemas
│ ├── chat.py # Chat schemas
│ ├── user.py # User and authentication schemas │ ├── user.py # User and authentication schemas
│ └── audit.py # Audit logs schemas │ └── audit.py # Audit logs schemas
├── services/ ├── services/
│ ├── agent_service.py # Business logic for agents │ ├── agent_service.py # Business logic for agents
│ ├── agent_runner.py # Agent execution logic
│ ├── auth_service.py # JWT authentication logic
│ ├── audit_service.py # Audit logs logic
│ ├── client_service.py # Business logic for clients │ ├── client_service.py # Business logic for clients
│ ├── contact_service.py # Business logic for contacts │ ├── contact_service.py # Business logic for contacts
│ ├── mcp_server_service.py # Business logic for MCP servers
│ ├── tool_service.py # Business logic for tools
│ ├── user_service.py # User and authentication logic
│ ├── auth_service.py # JWT authentication logic
│ ├── email_service.py # Email sending service │ ├── email_service.py # Email sending service
│ └── audit_service.py # Audit logs logic │ ├── mcp_server_service.py # Business logic for MCP servers
│ ├── session_service.py # Business logic for chat sessions
│ ├── tool_service.py # Business logic for tools
│ └── user_service.py # User management logic
├── templates/ ├── templates/
│ ├── emails/ │ ├── emails/
│ │ ├── base_email.html # Base template with common structure and styles │ │ ├── base_email.html # Base template with common structure and styles
@ -49,7 +59,21 @@ src/
│ │ ├── password_reset.html # Password reset template │ │ ├── password_reset.html # Password reset template
│ │ ├── welcome_email.html # Welcome email after verification │ │ ├── welcome_email.html # Welcome email after verification
│ │ └── account_locked.html # Security alert for locked accounts │ │ └── account_locked.html # Security alert for locked accounts
├── tests/
│ ├── __init__.py # Package initialization
│ ├── api/
│ │ ├── __init__.py # Package initialization
│ │ ├── test_auth_routes.py # Test for authentication routes
│ │ └── test_root.py # Test for root endpoint
│ ├── models/
│ │ ├── __init__.py # Package initialization
│ │ ├── test_models.py # Test for models
│ ├── services/
│ │ ├── __init__.py # Package initialization
│ │ ├── test_auth_service.py # Test for authentication service
│ │ └── test_user_service.py # Test for user service
└── utils/ └── utils/
├── logger.py # Logger configuration
└── security.py # Security utilities (JWT, hash) └── security.py # Security utilities (JWT, hash)
``` ```
@ -63,6 +87,15 @@ src/
- Code examples in documentation must be in English - Code examples in documentation must be in English
- Commit messages must be in English - Commit messages must be in English
### Project Configuration
- Dependencies managed in `pyproject.toml` using modern Python packaging standards
- Development dependencies specified as optional dependencies in `pyproject.toml`
- Single source of truth for project metadata in `pyproject.toml`
- Build system configured to use setuptools
- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]`
- Code formatting with Black configured in `pyproject.toml`
- Linting with Flake8 configured in `.flake8`
### Schemas (Pydantic) ### Schemas (Pydantic)
- Use `BaseModel` as base for all schemas - Use `BaseModel` as base for all schemas
- Define fields with explicit types - Define fields with explicit types
@ -136,6 +169,28 @@ src/
- Indentation with 4 spaces - Indentation with 4 spaces
- Maximum of 79 characters per line - Maximum of 79 characters per line
## Commit Rules
- Use Conventional Commits format for all commit messages
- Format: `<type>(<scope>): <description>`
- Types:
- `feat`: A new feature
- `fix`: A bug fix
- `docs`: Documentation changes
- `style`: Changes that do not affect code meaning (formatting, etc.)
- `refactor`: Code changes that neither fix a bug nor add a feature
- `perf`: Performance improvements
- `test`: Adding or modifying tests
- `chore`: Changes to build process or auxiliary tools
- Scope is optional and should be the module or component affected
- Description should be concise, in the imperative mood, and not capitalized
- Use body for more detailed explanations if needed
- Reference issues in the footer with `Fixes #123` or `Relates to #123`
- Examples:
- `feat(auth): add password reset functionality`
- `fix(api): correct validation error in client registration`
- `docs: update API documentation for new endpoints`
- `refactor(services): improve error handling in authentication`
## Best Practices ## Best Practices
- Always validate input data - Always validate input data
- Implement appropriate logging - Implement appropriate logging
@ -163,6 +218,7 @@ src/
## Useful Commands ## Useful Commands
- `make run`: Start the server - `make run`: Start the server
- `make run-prod`: Start the server in production mode
- `make alembic-revision message="description"`: Create new migration - `make alembic-revision message="description"`: Create new migration
- `make alembic-upgrade`: Apply pending migrations - `make alembic-upgrade`: Apply pending migrations
- `make alembic-downgrade`: Revert last migration - `make alembic-downgrade`: Revert last migration
@ -170,3 +226,8 @@ src/
- `make alembic-reset`: Reset database to initial state - `make alembic-reset`: Reset database to initial state
- `make alembic-upgrade-cascade`: Force upgrade removing dependencies - `make alembic-upgrade-cascade`: Force upgrade removing dependencies
- `make clear-cache`: Clean project cache - `make clear-cache`: Clean project cache
- `make seed-all`: Run all database seeders
- `make lint`: Run linting checks with flake8
- `make format`: Format code with black
- `make install`: Install project for development
- `make install-dev`: Install project with development dependencies

View File

@ -1,3 +1,47 @@
# Environment and IDE
.venv
venv
.env
.idea
.vscode
__pycache__
*.pyc
*.pyo
*.pyd
.Python
.pytest_cache
.coverage
htmlcov/
.tox/
# Version control
.git
.github
.gitignore
# Logs and temp files
logs
*.log
tmp
.DS_Store
# Docker
.dockerignore
Dockerfile*
docker-compose*
# Documentation
README.md
LICENSE
docs/
# Development tools
tests/
.flake8
pyproject.toml
requirements-dev.txt
Makefile
# Ambiente virtual # Ambiente virtual
venv/ venv/
__pycache__/ __pycache__/

8
.flake8 Normal file
View File

@ -0,0 +1,8 @@
[flake8]
max-line-length = 88
exclude = .git,__pycache__,venv,alembic/versions/*
ignore = E203, W503, E501
per-file-ignores =
__init__.py: F401
src/models/models.py: E712
alembic/*: E711,E712,F401

77
.gitignore vendored
View File

@ -1,13 +1,10 @@
# Byte-compiled / optimized / DLL files # Python
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
# C extensions
*.so *.so
# Distribution / packaging
.Python .Python
env/
build/ build/
develop-eggs/ develop-eggs/
dist/ dist/
@ -19,19 +16,10 @@ lib64/
parts/ parts/
sdist/ sdist/
var/ var/
wheels/
*.egg-info/ *.egg-info/
.installed.cfg .installed.cfg
*.egg *.egg
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports # Unit test / coverage reports
htmlcov/ htmlcov/
.tox/ .tox/
@ -42,14 +30,57 @@ nosetests.xml
coverage.xml coverage.xml
*.cover *.cover
.hypothesis/ .hypothesis/
.pytest_cache/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# IDE
.idea/
.vscode/
*.sublime-project
*.sublime-workspace
.DS_Store
# Logs
logs/
*.log
# Database
*.db
*.sqlite
*.sqlite3
backup/
# Local
local_settings.py
local.py
# Docker
.docker/
# Alembic versions
# alembic/versions/
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Translations # Translations
*.mo *.mo
*.pot *.pot
# Django stuff: # Django stuff:
*.log
local_settings.py
db.sqlite3 db.sqlite3
# Flask stuff: # Flask stuff:
@ -77,17 +108,6 @@ celerybeat-schedule
# SageMath parsed files # SageMath parsed files
*.sage.py *.sage.py
# Environments
.env
.venv
env/
venv/
.venv/
.env/
ENV/
env.bak/
venv.bak/
# Spyder project settings # Spyder project settings
.spyderproject .spyderproject
.spyproject .spyproject
@ -102,11 +122,8 @@ venv.bak/
.mypy_cache/ .mypy_cache/
# IDE # IDE
.idea/
.vscode/
*.swp *.swp
*.swo *.swo
# OS # OS
.DS_Store
Thumbs.db Thumbs.db

View File

@ -1,4 +1,4 @@
FROM python:3.11-slim FROM python:3.10-slim
# Define o diretório de trabalho # Define o diretório de trabalho
WORKDIR /app WORKDIR /app
@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copia os arquivos de requisitos # Copy project files
COPY requirements.txt .
# Instala as dependências
RUN pip install --no-cache-dir -r requirements.txt
# Copia o código-fonte
COPY . . COPY . .
# Install dependencies
RUN pip install --no-cache-dir -e .
# Configuração para produção # Configuração para produção
ENV PORT=8000 \ ENV PORT=8000 \
HOST=0.0.0.0 \ HOST=0.0.0.0 \
DEBUG=false DEBUG=false
# Expose port
EXPOSE 8000
# Define o comando de inicialização # Define o comando de inicialização
CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT

View File

@ -1,42 +1,42 @@
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs .PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv
# Comandos do Alembic # Alembic commands
init: init:
alembic init alembics alembic init alembics
# make alembic-revision message="descrição da migração" # make alembic-revision message="migration description"
alembic-revision: alembic-revision:
alembic revision --autogenerate -m "$(message)" alembic revision --autogenerate -m "$(message)"
# Comando para atualizar o banco de dados # Command to update database to latest version
alembic-upgrade: alembic-upgrade:
alembic upgrade head alembic upgrade head
# Comando para voltar uma versão # Command to downgrade one version
alembic-downgrade: alembic-downgrade:
alembic downgrade -1 alembic downgrade -1
# Comando para rodar o servidor # Command to run the server
run: run:
uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src
# Comando para limpar o cache em todas as pastas do projeto pastas pycache # Command to run the server in production mode
run-prod:
uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4
# Command to clean cache in all project folders
clear-cache: clear-cache:
rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} + rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +
# Comando para criar uma nova migração # Command to create a new migration and apply it
alembic-migrate: alembic-migrate:
alembic revision --autogenerate -m "$(message)" && alembic upgrade head alembic revision --autogenerate -m "$(message)" && alembic upgrade head
# Comando para resetar o banco de dados # Command to reset the database
alembic-reset: alembic-reset:
alembic downgrade base && alembic upgrade head alembic downgrade base && alembic upgrade head
# Comando para forçar upgrade com CASCADE # Commands to run seeders
alembic-upgrade-cascade:
psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head
# Comandos para executar seeders
seed-admin: seed-admin:
python -m scripts.seeders.admin_seeder python -m scripts.seeders.admin_seeder
@ -58,7 +58,7 @@ seed-contacts:
seed-all: seed-all:
python -m scripts.run_seeders python -m scripts.run_seeders
# Comandos Docker # Docker commands
docker-build: docker-build:
docker-compose build docker-compose build
@ -72,4 +72,21 @@ docker-logs:
docker-compose logs -f docker-compose logs -f
docker-seed: docker-seed:
docker-compose exec api python -m scripts.run_seeders docker-compose exec api python -m scripts.run_seeders
# Testing, linting and formatting commands
lint:
flake8 src/ tests/
format:
black src/ tests/
# Virtual environment and installation commands
venv:
python -m venv venv
install:
pip install -e .
install-dev:
pip install -e ".[dev]"

53
conftest.py Normal file
View File

@ -0,0 +1,53 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from src.config.database import Base, get_db
from src.main import app
# Use in-memory SQLite for tests
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="function")
def db_session():
"""Creates a fresh database session for each test."""
Base.metadata.create_all(bind=engine) # Create tables
connection = engine.connect()
transaction = connection.begin()
session = TestingSessionLocal(bind=connection)
# Use our test database instead of the standard one
def override_get_db():
try:
yield session
session.commit()
finally:
session.close()
app.dependency_overrides[get_db] = override_get_db
yield session # The test will run here
# Teardown
transaction.rollback()
connection.close()
Base.metadata.drop_all(bind=engine)
app.dependency_overrides.clear()
@pytest.fixture(scope="function")
def client(db_session):
"""Creates a FastAPI TestClient with database session fixture."""
with TestClient(app) as test_client:
yield test_client

89
pyproject.toml Normal file
View File

@ -0,0 +1,89 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "evo-ai"
version = "1.0.0"
description = "API for executing AI agents"
readme = "README.md"
authors = [
{name = "EvoAI Team", email = "admin@evoai.com"}
]
requires-python = ">=3.10"
license = {text = "Proprietary"}
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"License :: Other/Proprietary License",
"Operating System :: OS Independent",
]
# Main dependencies
dependencies = [
"fastapi==0.115.12",
"uvicorn==0.34.2",
"pydantic==2.11.3",
"sqlalchemy==2.0.40",
"psycopg2-binary==2.9.10",
"google-cloud-aiplatform==1.90.0",
"python-dotenv==1.1.0",
"google-adk==0.3.0",
"litellm==1.67.4.post1",
"python-multipart==0.0.20",
"alembic==1.15.2",
"asyncpg==0.30.0",
"python-jose==3.4.0",
"passlib==1.7.4",
"sendgrid==6.11.0",
"pydantic-settings==2.9.1",
"fastapi_utils==0.8.0",
"bcrypt==4.3.0",
"jinja2==3.1.6",
]
[project.optional-dependencies]
dev = [
"black==25.1.0",
"flake8==7.2.0",
"pytest==8.3.5",
"pytest-cov==6.1.1",
"httpx==0.28.1",
"pytest-asyncio==0.26.0",
]
[tool.setuptools]
packages = ["src"]
[tool.black]
line-length = 88
target-version = ["py310"]
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
| venv
| alembic/versions
)/
'''
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_functions = "test_*"
python_classes = "Test*"
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["src"]
omit = ["tests/*", "alembic/*"]

View File

@ -1,22 +0,0 @@
fastapi
uvicorn
pydantic
sqlalchemy
psycopg2
psycopg2-binary
google-cloud-aiplatform
python-dotenv
google-adk
litellm
python-multipart
alembic
asyncpg
# Novas dependências para autenticação
python-jose[cryptography]
passlib[bcrypt]
sendgrid
pydantic[email]
pydantic-settings
fastapi_utils
bcrypt
jinja2

6
setup.py Normal file
View File

@ -0,0 +1,6 @@
"""Setup script for the package."""
from setuptools import setup
if __name__ == "__main__":
setup()

View File

@ -1,3 +1,3 @@
""" """
Pacote principal da aplicação Main package of the application
""" """

View File

@ -1,14 +1,17 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
import uuid import uuid
from src.config.database import get_db from src.config.database import get_db
from src.core.jwt_middleware import get_jwt_token, verify_admin from src.core.jwt_middleware import get_jwt_token, verify_admin
from src.schemas.audit import AuditLogResponse, AuditLogFilter from src.schemas.audit import AuditLogResponse, AuditLogFilter
from src.services.audit_service import get_audit_logs, create_audit_log from src.services.audit_service import get_audit_logs, create_audit_log
from src.services.user_service import get_admin_users, create_admin_user, deactivate_user from src.services.user_service import (
get_admin_users,
create_admin_user,
deactivate_user,
)
from src.schemas.user import UserResponse, AdminUserCreate from src.schemas.user import UserResponse, AdminUserCreate
router = APIRouter( router = APIRouter(
@ -18,6 +21,7 @@ router = APIRouter(
responses={403: {"description": "Permission denied"}}, responses={403: {"description": "Permission denied"}},
) )
# Audit routes # Audit routes
@router.get("/audit-logs", response_model=List[AuditLogResponse]) @router.get("/audit-logs", response_model=List[AuditLogResponse])
async def read_audit_logs( async def read_audit_logs(
@ -27,12 +31,12 @@ async def read_audit_logs(
): ):
""" """
Get audit logs with optional filters Get audit logs with optional filters
Args: Args:
filters: Filters for log search filters: Filters for log search
db: Database session db: Database session
payload: JWT token payload payload: JWT token payload
Returns: Returns:
List[AuditLogResponse]: List of audit logs List[AuditLogResponse]: List of audit logs
""" """
@ -45,9 +49,10 @@ async def read_audit_logs(
resource_type=filters.resource_type, resource_type=filters.resource_type,
resource_id=filters.resource_id, resource_id=filters.resource_id,
start_date=filters.start_date, start_date=filters.start_date,
end_date=filters.end_date end_date=filters.end_date,
) )
# Admin routes # Admin routes
@router.get("/users", response_model=List[UserResponse]) @router.get("/users", response_model=List[UserResponse])
async def read_admin_users( async def read_admin_users(
@ -58,18 +63,19 @@ async def read_admin_users(
): ):
""" """
List admin users List admin users
Args: Args:
skip: Number of records to skip skip: Number of records to skip
limit: Maximum number of records to return limit: Maximum number of records to return
db: Database session db: Database session
payload: JWT token payload payload: JWT token payload
Returns: Returns:
List[UserResponse]: List of admin users List[UserResponse]: List of admin users
""" """
return get_admin_users(db, skip, limit) return get_admin_users(db, skip, limit)
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_new_admin_user( async def create_new_admin_user(
user_data: AdminUserCreate, user_data: AdminUserCreate,
@ -79,16 +85,16 @@ async def create_new_admin_user(
): ):
""" """
Create a new admin user Create a new admin user
Args: Args:
user_data: User data to be created user_data: User data to be created
request: FastAPI Request object request: FastAPI Request object
db: Database session db: Database session
payload: JWT token payload payload: JWT token payload
Returns: Returns:
UserResponse: Created user data UserResponse: Created user data
Raises: Raises:
HTTPException: If there is an error in creation HTTPException: If there is an error in creation
""" """
@ -97,17 +103,14 @@ async def create_new_admin_user(
if not user_id: if not user_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unable to identify the logged in user" detail="Unable to identify the logged in user",
) )
# Create admin user # Create admin user
user, message = create_admin_user(db, user_data) user, message = create_admin_user(db, user_data)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
status_code=status.HTTP_400_BAD_REQUEST,
detail=message
)
# Register action in audit log # Register action in audit log
create_audit_log( create_audit_log(
db, db,
@ -116,11 +119,12 @@ async def create_new_admin_user(
resource_type="admin_user", resource_type="admin_user",
resource_id=str(user.id), resource_id=str(user.id),
details={"email": user.email}, details={"email": user.email},
request=request request=request,
) )
return user return user
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_admin_user( async def deactivate_admin_user(
user_id: uuid.UUID, user_id: uuid.UUID,
@ -130,13 +134,13 @@ async def deactivate_admin_user(
): ):
""" """
Deactivate an admin user (does not delete, only deactivates) Deactivate an admin user (does not delete, only deactivates)
Args: Args:
user_id: ID of the user to be deactivated user_id: ID of the user to be deactivated
request: FastAPI Request object request: FastAPI Request object
db: Database session db: Database session
payload: JWT token payload payload: JWT token payload
Raises: Raises:
HTTPException: If there is an error in deactivation HTTPException: If there is an error in deactivation
""" """
@ -145,24 +149,21 @@ async def deactivate_admin_user(
if not current_user_id: if not current_user_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unable to identify the logged in user" detail="Unable to identify the logged in user",
) )
# Do not allow deactivating yourself # Do not allow deactivating yourself
if str(user_id) == current_user_id: if str(user_id) == current_user_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Unable to deactivate your own user" detail="Unable to deactivate your own user",
) )
# Deactivate user # Deactivate user
success, message = deactivate_user(db, user_id) success, message = deactivate_user(db, user_id)
if not success: if not success:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
status_code=status.HTTP_400_BAD_REQUEST,
detail=message
)
# Register action in audit log # Register action in audit log
create_audit_log( create_audit_log(
db, db,
@ -171,5 +172,5 @@ async def deactivate_admin_user(
resource_type="admin_user", resource_type="admin_user",
resource_id=str(user_id), resource_id=str(user_id),
details=None, details=None,
request=request request=request,
) )

View File

@ -7,10 +7,6 @@ from src.core.jwt_middleware import (
get_jwt_token, get_jwt_token,
verify_user_client, verify_user_client,
) )
from src.core.jwt_middleware import (
get_jwt_token,
verify_user_client,
)
from src.schemas.schemas import ( from src.schemas.schemas import (
Agent, Agent,
AgentCreate, AgentCreate,

View File

@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get
""" """
user = authenticate_user(db, form_data.email, form_data.password) user = authenticate_user(db, form_data.email, form_data.password)
if not user: if not user:
logger.warning( logger.warning(f"Login attempt with invalid credentials: {form_data.email}")
f"Login attempt with invalid credentials: {form_data.email}"
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password", detail="Invalid email or password",

View File

@ -11,7 +11,11 @@ from src.services import (
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
from src.services.agent_runner import run_agent from src.services.agent_runner import run_agent
from src.core.exceptions import AgentNotFoundError from src.core.exceptions import AgentNotFoundError
from src.main import session_service, artifacts_service, memory_service from src.services.service_providers import (
session_service,
artifacts_service,
memory_service,
)
from datetime import datetime from datetime import datetime
import logging import logging
@ -71,4 +75,4 @@ async def chat(
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
) )

View File

@ -19,7 +19,7 @@ from src.services.session_service import (
get_sessions_by_agent, get_sessions_by_agent,
get_sessions_by_client, get_sessions_by_client,
) )
from src.main import session_service from src.services.service_providers import session_service
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,6 +30,7 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
# Session Routes # Session Routes
@router.get("/client/{client_id}", response_model=List[Adk_Session]) @router.get("/client/{client_id}", response_model=List[Adk_Session])
async def get_client_sessions( async def get_client_sessions(

View File

@ -10,9 +10,10 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() db.close()

View File

@ -1,55 +1,66 @@
from fastapi import HTTPException from fastapi import HTTPException
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
class BaseAPIException(HTTPException): class BaseAPIException(HTTPException):
"""Base class for API exceptions""" """Base class for API exceptions"""
def __init__( def __init__(
self, self,
status_code: int, status_code: int,
message: str, message: str,
error_code: str, error_code: str,
details: Optional[Dict[str, Any]] = None details: Optional[Dict[str, Any]] = None,
): ):
super().__init__(status_code=status_code, detail={ super().__init__(
"error": message, status_code=status_code,
"error_code": error_code, detail={
"details": details or {} "error": message,
}) "error_code": error_code,
"details": details or {},
},
)
class AgentNotFoundError(BaseAPIException): class AgentNotFoundError(BaseAPIException):
"""Exception when the agent is not found""" """Exception when the agent is not found"""
def __init__(self, agent_id: str): def __init__(self, agent_id: str):
super().__init__( super().__init__(
status_code=404, status_code=404,
message=f"Agent with ID {agent_id} not found", message=f"Agent with ID {agent_id} not found",
error_code="AGENT_NOT_FOUND" error_code="AGENT_NOT_FOUND",
) )
class InvalidParameterError(BaseAPIException): class InvalidParameterError(BaseAPIException):
"""Exception for invalid parameters""" """Exception for invalid parameters"""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
super().__init__( super().__init__(
status_code=400, status_code=400,
message=message, message=message,
error_code="INVALID_PARAMETER", error_code="INVALID_PARAMETER",
details=details details=details,
) )
class InvalidRequestError(BaseAPIException): class InvalidRequestError(BaseAPIException):
"""Exception for invalid requests""" """Exception for invalid requests"""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
super().__init__( super().__init__(
status_code=400, status_code=400,
message=message, message=message,
error_code="INVALID_REQUEST", error_code="INVALID_REQUEST",
details=details details=details,
) )
class InternalServerError(BaseAPIException): class InternalServerError(BaseAPIException):
"""Exception for server errors""" """Exception for server errors"""
def __init__(self, message: str = "Server error"): def __init__(self, message: str = "Server error"):
super().__init__( super().__init__(
status_code=500, status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR"
message=message, )
error_code="INTERNAL_SERVER_ERROR"
)

View File

@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict: async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
""" """
Extracts and validates the JWT token Extracts and validates the JWT token
Args: Args:
token: Token JWT token: Token JWT
Returns: Returns:
dict: Token payload data dict: Token payload data
Raises: Raises:
HTTPException: If the token is invalid HTTPException: If the token is invalid
""" """
@ -31,86 +32,90 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
detail="Invalid credentials", detail="Invalid credentials",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode( payload = jwt.decode(
token, token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
) )
email: str = payload.get("sub") email: str = payload.get("sub")
if email is None: if email is None:
logger.warning("Token without email (sub)") logger.warning("Token without email (sub)")
raise credentials_exception raise credentials_exception
exp = payload.get("exp") exp = payload.get("exp")
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow(): if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
logger.warning(f"Token expired for {email}") logger.warning(f"Token expired for {email}")
raise credentials_exception raise credentials_exception
return payload return payload
except JWTError as e: except JWTError as e:
logger.error(f"Error decoding JWT token: {str(e)}") logger.error(f"Error decoding JWT token: {str(e)}")
raise credentials_exception raise credentials_exception
async def verify_user_client( async def verify_user_client(
payload: dict = Depends(get_jwt_token), payload: dict = Depends(get_jwt_token),
db: Session = Depends(get_db), db: Session = Depends(get_db),
required_client_id: UUID = None required_client_id: UUID = None,
) -> bool: ) -> bool:
""" """
Verifies if the user is associated with the specified client Verifies if the user is associated with the specified client
Args: Args:
payload: Token JWT payload payload: Token JWT payload
db: Database session db: Database session
required_client_id: Client ID to be verified required_client_id: Client ID to be verified
Returns: Returns:
bool: True se verificado com sucesso bool: True se verificado com sucesso
Raises: Raises:
HTTPException: If the user does not have permission HTTPException: If the user does not have permission
""" """
# Administrators have access to all clients # Administrators have access to all clients
if payload.get("is_admin", False): if payload.get("is_admin", False):
return True return True
# Para não-admins, verificar se o client_id corresponde # Para não-admins, verificar se o client_id corresponde
user_client_id = payload.get("client_id") user_client_id = payload.get("client_id")
if not user_client_id: if not user_client_id:
logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}") logger.warning(
f"Non-admin user without client_id in token: {payload.get('sub')}"
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="User not associated with a client" detail="User not associated with a client",
) )
# If no client_id is specified to verify, any client is valid # If no client_id is specified to verify, any client is valid
if not required_client_id: if not required_client_id:
return True return True
# Verify if the user's client_id corresponds to the required_client_id # Verify if the user's client_id corresponds to the required_client_id
if str(user_client_id) != str(required_client_id): if str(user_client_id) != str(required_client_id):
logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}") logger.warning(
f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}"
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to access resources of this client" detail="Access denied to access resources of this client",
) )
return True return True
async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool: async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
""" """
Verifies if the user is an administrator Verifies if the user is an administrator
Args: Args:
payload: Token JWT payload payload: Token JWT payload
Returns: Returns:
bool: True if the user is an administrator bool: True if the user is an administrator
Raises: Raises:
HTTPException: If the user is not an administrator HTTPException: If the user is not an administrator
""" """
@ -118,26 +123,29 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
logger.warning(f"Access denied to admin: User {payload.get('sub')}") logger.warning(f"Access denied to admin: User {payload.get('sub')}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. Restricted to administrators." detail="Access denied. Restricted to administrators.",
) )
return True return True
def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]:
def get_current_user_client_id(
payload: dict = Depends(get_jwt_token),
) -> Optional[UUID]:
""" """
Gets the ID of the client associated with the current user Gets the ID of the client associated with the current user
Args: Args:
payload: Token JWT payload payload: Token JWT payload
Returns: Returns:
Optional[UUID]: Client ID or None if the user is an administrator Optional[UUID]: Client ID or None if the user is an administrator
""" """
if payload.get("is_admin", False): if payload.get("is_admin", False):
return None return None
client_id = payload.get("client_id") client_id = payload.get("client_id")
if client_id: if client_id:
return UUID(client_id) return UUID(client_id)
return None return None

View File

@ -1,35 +1,30 @@
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
# Add the root directory to PYTHONPATH
root_dir = Path(__file__).parent.parent
sys.path.append(str(root_dir))
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.config.database import engine, Base from src.config.database import engine, Base
from src.config.settings import settings from src.config.settings import settings
from src.utils.logger import setup_logger from src.utils.logger import setup_logger
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
# Initialize service instances # Necessary for other modules
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING) from src.services.service_providers import session_service # noqa: F401
artifacts_service = InMemoryArtifactService() from src.services.service_providers import artifacts_service # noqa: F401
memory_service = InMemoryMemoryService() from src.services.service_providers import memory_service # noqa: F401
# Import routers after service initialization to avoid circular imports import src.api.auth_routes
from src.api.auth_routes import router as auth_router import src.api.admin_routes
from src.api.admin_routes import router as admin_router import src.api.chat_routes
from src.api.chat_routes import router as chat_router import src.api.session_routes
from src.api.session_routes import router as session_router import src.api.agent_routes
from src.api.agent_routes import router as agent_router import src.api.contact_routes
from src.api.contact_routes import router as contact_router import src.api.mcp_server_routes
from src.api.mcp_server_routes import router as mcp_server_router import src.api.tool_routes
from src.api.tool_routes import router as tool_router import src.api.client_routes
from src.api.client_routes import router as client_router
# Add the root directory to PYTHONPATH
root_dir = Path(__file__).parent.parent
sys.path.append(str(root_dir))
# Configure logger # Configure logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
@ -52,8 +47,7 @@ app.add_middleware(
# PostgreSQL configuration # PostgreSQL configuration
POSTGRES_CONNECTION_STRING = os.getenv( POSTGRES_CONNECTION_STRING = os.getenv(
"POSTGRES_CONNECTION_STRING", "POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
"postgresql://postgres:root@localhost:5432/evo_ai"
) )
# Create database tables # Create database tables
@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine)
API_PREFIX = "/api/v1" API_PREFIX = "/api/v1"
# Define router references
auth_router = src.api.auth_routes.router
admin_router = src.api.admin_routes.router
chat_router = src.api.chat_routes.router
session_router = src.api.session_routes.router
agent_router = src.api.agent_routes.router
contact_router = src.api.contact_routes.router
mcp_server_router = src.api.mcp_server_routes.router
tool_router = src.api.tool_routes.router
client_router = src.api.client_routes.router
# Include routes # Include routes
app.include_router(auth_router, prefix=API_PREFIX) app.include_router(auth_router, prefix=API_PREFIX)
app.include_router(admin_router, prefix=API_PREFIX) app.include_router(admin_router, prefix=API_PREFIX)
@ -79,5 +84,5 @@ def read_root():
"message": "Welcome to Evo AI API", "message": "Welcome to Evo AI API",
"documentation": "/docs", "documentation": "/docs",
"version": settings.API_VERSION, "version": settings.API_VERSION,
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'" "auth": "To access the API, use JWT authentication via '/api/v1/auth/login'",
} }

View File

@ -1,9 +1,20 @@
from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean from sqlalchemy import (
Column,
String,
UUID,
DateTime,
ForeignKey,
JSON,
Text,
CheckConstraint,
Boolean,
)
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, backref from sqlalchemy.orm import relationship, backref
from src.config.database import Base from src.config.database import Base
import uuid import uuid
class Client(Base): class Client(Base):
__tablename__ = "clients" __tablename__ = "clients"
@ -13,13 +24,16 @@ class Client(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
email = Column(String, unique=True, index=True, nullable=False) email = Column(String, unique=True, index=True, nullable=False)
password_hash = Column(String, nullable=False) password_hash = Column(String, nullable=False)
client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True) client_id = Column(
UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True
)
is_active = Column(Boolean, default=False) is_active = Column(Boolean, default=False)
is_admin = Column(Boolean, default=False) is_admin = Column(Boolean, default=False)
email_verified = Column(Boolean, default=False) email_verified = Column(Boolean, default=False)
@ -29,9 +43,12 @@ class User(Base):
password_reset_expiry = Column(DateTime(timezone=True), nullable=True) password_reset_expiry = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
# Relationship with Client (One-to-One, optional for administrators) # Relationship with Client (One-to-One, optional for administrators)
client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")) client = relationship(
"Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")
)
class Contact(Base): class Contact(Base):
__tablename__ = "contacts" __tablename__ = "contacts"
@ -44,6 +61,7 @@ class Contact(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class Agent(Base): class Agent(Base):
__tablename__ = "agents" __tablename__ = "agents"
@ -60,21 +78,30 @@ class Agent(Base):
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
__table_args__ = ( __table_args__ = (
CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'), CheckConstraint(
"type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
),
) )
def to_dict(self): def to_dict(self):
"""Converts the object to a dictionary, converting UUIDs to strings""" """Converts the object to a dictionary, converting UUIDs to strings"""
result = {} result = {}
for key, value in self.__dict__.items(): for key, value in self.__dict__.items():
if key.startswith('_'): if key.startswith("_"):
continue continue
if isinstance(value, uuid.UUID): if isinstance(value, uuid.UUID):
result[key] = str(value) result[key] = str(value)
elif isinstance(value, dict): elif isinstance(value, dict):
result[key] = self._convert_dict(value) result[key] = self._convert_dict(value)
elif isinstance(value, list): elif isinstance(value, list):
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value] result[key] = [
(
self._convert_dict(item)
if isinstance(item, dict)
else str(item) if isinstance(item, uuid.UUID) else item
)
for item in value
]
else: else:
result[key] = value result[key] = value
return result return result
@ -88,11 +115,19 @@ class Agent(Base):
elif isinstance(value, dict): elif isinstance(value, dict):
result[key] = self._convert_dict(value) result[key] = self._convert_dict(value)
elif isinstance(value, list): elif isinstance(value, list):
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value] result[key] = [
(
self._convert_dict(item)
if isinstance(item, dict)
else str(item) if isinstance(item, uuid.UUID) else item
)
for item in value
]
else: else:
result[key] = value result[key] = value
return result return result
class MCPServer(Base): class MCPServer(Base):
__tablename__ = "mcp_servers" __tablename__ = "mcp_servers"
@ -105,11 +140,14 @@ class MCPServer(Base):
type = Column(String, nullable=False, default="official") type = Column(String, nullable=False, default="official")
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
__table_args__ = ( __table_args__ = (
CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'), CheckConstraint(
"type IN ('official', 'community')", name="check_mcp_server_type"
),
) )
class Tool(Base): class Tool(Base):
__tablename__ = "tools" __tablename__ = "tools"
@ -121,11 +159,12 @@ class Tool(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class Session(Base): class Session(Base):
__tablename__ = "sessions" __tablename__ = "sessions"
# The directive below makes Alembic ignore this table in migrations # The directive below makes Alembic ignore this table in migrations
__table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}} __table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}}
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
app_name = Column(String) app_name = Column(String)
user_id = Column(String) user_id = Column(String)
@ -133,11 +172,14 @@ class Session(Base):
create_time = Column(DateTime(timezone=True)) create_time = Column(DateTime(timezone=True))
update_time = Column(DateTime(timezone=True)) update_time = Column(DateTime(timezone=True))
class AuditLog(Base): class AuditLog(Base):
__tablename__ = "audit_logs" __tablename__ = "audit_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) user_id = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
action = Column(String, nullable=False) action = Column(String, nullable=False)
resource_type = Column(String, nullable=False) resource_type = Column(String, nullable=False)
resource_id = Column(String, nullable=True) resource_id = Column(String, nullable=True)
@ -145,6 +187,6 @@ class AuditLog(Base):
ip_address = Column(String, nullable=True) ip_address = Column(String, nullable=True)
user_agent = Column(String, nullable=True) user_agent = Column(String, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationship with User # Relationship with User
user = relationship("User", backref="audit_logs") user = relationship("User", backref="audit_logs")

View File

@ -1,26 +1,38 @@
from typing import List, Optional, Dict, Any, Union from typing import List, Optional, Dict, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from uuid import UUID from uuid import UUID
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
"""Configuration of a tool""" """Configuration of a tool"""
id: UUID id: UUID
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool") envs: Dict[str, str] = Field(
default_factory=dict, description="Environment variables of the tool"
)
class Config: class Config:
from_attributes = True from_attributes = True
class MCPServerConfig(BaseModel): class MCPServerConfig(BaseModel):
"""Configuration of an MCP server""" """Configuration of an MCP server"""
id: UUID id: UUID
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server") envs: Dict[str, str] = Field(
tools: List[str] = Field(default_factory=list, description="List of tools of the server") default_factory=dict, description="Environment variables of the server"
)
tools: List[str] = Field(
default_factory=list, description="List of tools of the server"
)
class Config: class Config:
from_attributes = True from_attributes = True
class HTTPToolParameter(BaseModel): class HTTPToolParameter(BaseModel):
"""Parameter of an HTTP tool""" """Parameter of an HTTP tool"""
type: str type: str
required: bool required: bool
description: str description: str
@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class HTTPToolParameters(BaseModel): class HTTPToolParameters(BaseModel):
"""Parameters of an HTTP tool""" """Parameters of an HTTP tool"""
path_params: Optional[Dict[str, str]] = None path_params: Optional[Dict[str, str]] = None
query_params: Optional[Dict[str, Union[str, List[str]]]] = None query_params: Optional[Dict[str, Union[str, List[str]]]] = None
body_params: Optional[Dict[str, HTTPToolParameter]] = None body_params: Optional[Dict[str, HTTPToolParameter]] = None
@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class HTTPToolErrorHandling(BaseModel): class HTTPToolErrorHandling(BaseModel):
"""Configuration of error handling""" """Configuration of error handling"""
timeout: int timeout: int
retry_count: int retry_count: int
fallback_response: Dict[str, str] fallback_response: Dict[str, str]
@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class HTTPTool(BaseModel): class HTTPTool(BaseModel):
"""Configuration of an HTTP tool""" """Configuration of an HTTP tool"""
name: str name: str
method: str method: str
values: Dict[str, str] values: Dict[str, str]
@ -60,42 +78,72 @@ class HTTPTool(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class CustomTools(BaseModel): class CustomTools(BaseModel):
"""Configuration of custom tools""" """Configuration of custom tools"""
http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools")
http_tools: List[HTTPTool] = Field(
default_factory=list, description="List of HTTP tools"
)
class Config: class Config:
from_attributes = True from_attributes = True
class LLMConfig(BaseModel): class LLMConfig(BaseModel):
"""Configuration for LLM agents""" """Configuration for LLM agents"""
tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools")
custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools") tools: Optional[List[ToolConfig]] = Field(
mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers") default=None, description="List of available tools"
sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents") )
custom_tools: Optional[CustomTools] = Field(
default=None, description="Custom tools"
)
mcp_servers: Optional[List[MCPServerConfig]] = Field(
default=None, description="List of MCP servers"
)
sub_agents: Optional[List[UUID]] = Field(
default=None, description="List of IDs of sub-agents"
)
class Config: class Config:
from_attributes = True from_attributes = True
class SequentialConfig(BaseModel): class SequentialConfig(BaseModel):
"""Configuration for sequential agents""" """Configuration for sequential agents"""
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order")
sub_agents: List[UUID] = Field(
..., description="List of IDs of sub-agents in execution order"
)
class Config: class Config:
from_attributes = True from_attributes = True
class ParallelConfig(BaseModel): class ParallelConfig(BaseModel):
"""Configuration for parallel agents""" """Configuration for parallel agents"""
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution")
sub_agents: List[UUID] = Field(
..., description="List of IDs of sub-agents for parallel execution"
)
class Config: class Config:
from_attributes = True from_attributes = True
class LoopConfig(BaseModel): class LoopConfig(BaseModel):
"""Configuration for loop agents""" """Configuration for loop agents"""
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution")
max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations") sub_agents: List[UUID] = Field(
condition: Optional[str] = Field(default=None, description="Condition to stop the loop") ..., description="List of IDs of sub-agents for loop execution"
)
max_iterations: Optional[int] = Field(
default=None, description="Maximum number of iterations"
)
condition: Optional[str] = Field(
default=None, description="Condition to stop the loop"
)
class Config: class Config:
from_attributes = True from_attributes = True

View File

@ -3,30 +3,38 @@ from typing import Optional, Dict, Any
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
class AuditLogBase(BaseModel): class AuditLogBase(BaseModel):
"""Base schema for audit log""" """Base schema for audit log"""
action: str action: str
resource_type: str resource_type: str
resource_id: Optional[str] = None resource_id: Optional[str] = None
details: Optional[Dict[str, Any]] = None details: Optional[Dict[str, Any]] = None
class AuditLogCreate(AuditLogBase): class AuditLogCreate(AuditLogBase):
"""Schema for creating audit log""" """Schema for creating audit log"""
pass pass
class AuditLogResponse(AuditLogBase): class AuditLogResponse(AuditLogBase):
"""Schema for audit log response""" """Schema for audit log response"""
id: UUID id: UUID
user_id: Optional[UUID] = None user_id: Optional[UUID] = None
ip_address: Optional[str] = None ip_address: Optional[str] = None
user_agent: Optional[str] = None user_agent: Optional[str] = None
created_at: datetime created_at: datetime
class Config: class Config:
from_attributes = True from_attributes = True
class AuditLogFilter(BaseModel): class AuditLogFilter(BaseModel):
"""Schema for audit log search filters""" """Schema for audit log search filters"""
user_id: Optional[UUID] = None user_id: Optional[UUID] = None
action: Optional[str] = None action: Optional[str] = None
resource_type: Optional[str] = None resource_type: Optional[str] = None
@ -34,4 +42,4 @@ class AuditLogFilter(BaseModel):
start_date: Optional[datetime] = None start_date: Optional[datetime] = None
end_date: Optional[datetime] = None end_date: Optional[datetime] = None
skip: Optional[int] = Field(0, ge=0) skip: Optional[int] = Field(0, ge=0)
limit: Optional[int] = Field(100, ge=1, le=1000) limit: Optional[int] = Field(100, ge=1, le=1000)

View File

@ -1,21 +1,33 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
"""Schema for chat requests""" """Schema for chat requests"""
agent_id: str = Field(..., description="ID of the agent that will process the message")
contact_id: str = Field(..., description="ID of the contact that will process the message") agent_id: str = Field(
..., description="ID of the agent that will process the message"
)
contact_id: str = Field(
..., description="ID of the contact that will process the message"
)
message: str = Field(..., description="User message") message: str = Field(..., description="User message")
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
"""Schema for chat responses""" """Schema for chat responses"""
response: str = Field(..., description="Agent response") response: str = Field(..., description="Agent response")
status: str = Field(..., description="Operation status") status: str = Field(..., description="Operation status")
error: Optional[str] = Field(None, description="Error message, if there is one") error: Optional[str] = Field(None, description="Error message, if there is one")
timestamp: str = Field(..., description="Timestamp of the response") timestamp: str = Field(..., description="Timestamp of the response")
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Schema for error responses""" """Schema for error responses"""
error: str = Field(..., description="Error message") error: str = Field(..., description="Error message")
status_code: int = Field(..., description="HTTP status code of the error") status_code: int = Field(..., description="HTTP status code of the error")
details: Optional[Dict[str, Any]] = Field(None, description="Additional error details") details: Optional[Dict[str, Any]] = Field(
None, description="Additional error details"
)

View File

@ -4,15 +4,18 @@ from datetime import datetime
from uuid import UUID from uuid import UUID
import uuid import uuid
import re import re
from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig from src.schemas.agent_config import LLMConfig
class ClientBase(BaseModel): class ClientBase(BaseModel):
name: str name: str
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
class ClientCreate(ClientBase): class ClientCreate(ClientBase):
pass pass
class Client(ClientBase): class Client(ClientBase):
id: UUID id: UUID
created_at: datetime created_at: datetime
@ -20,14 +23,17 @@ class Client(ClientBase):
class Config: class Config:
from_attributes = True from_attributes = True
class ContactBase(BaseModel): class ContactBase(BaseModel):
ext_id: Optional[str] = None ext_id: Optional[str] = None
name: Optional[str] = None name: Optional[str] = None
meta: Optional[Dict[str, Any]] = Field(default_factory=dict) meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
class ContactCreate(ContactBase): class ContactCreate(ContactBase):
client_id: UUID client_id: UUID
class Contact(ContactBase): class Contact(ContactBase):
id: UUID id: UUID
client_id: UUID client_id: UUID
@ -35,67 +41,80 @@ class Contact(ContactBase):
class Config: class Config:
from_attributes = True from_attributes = True
class AgentBase(BaseModel): class AgentBase(BaseModel):
name: str = Field(..., description="Agent name (no spaces or special characters)") name: str = Field(..., description="Agent name (no spaces or special characters)")
description: Optional[str] = Field(None, description="Agent description") description: Optional[str] = Field(None, description="Agent description")
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)") type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
model: Optional[str] = Field(None, description="Agent model (required only for llm type)") model: Optional[str] = Field(
api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)") None, description="Agent model (required only for llm type)"
)
api_key: Optional[str] = Field(
None, description="Agent API Key (required only for llm type)"
)
instruction: Optional[str] = None instruction: Optional[str] = None
config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type") config: Union[LLMConfig, Dict[str, Any]] = Field(
..., description="Agent configuration based on type"
)
@validator('name') @validator("name")
def validate_name(cls, v): def validate_name(cls, v):
if not re.match(r'^[a-zA-Z0-9_-]+$', v): if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError('Agent name cannot contain spaces or special characters') raise ValueError("Agent name cannot contain spaces or special characters")
return v return v
@validator('type') @validator("type")
def validate_type(cls, v): def validate_type(cls, v):
if v not in ['llm', 'sequential', 'parallel', 'loop']: if v not in ["llm", "sequential", "parallel", "loop"]:
raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop') raise ValueError(
"Invalid agent type. Must be: llm, sequential, parallel or loop"
)
return v return v
@validator('model') @validator("model")
def validate_model(cls, v, values): def validate_model(cls, v, values):
if 'type' in values and values['type'] == 'llm' and not v: if "type" in values and values["type"] == "llm" and not v:
raise ValueError('Model is required for llm type agents') raise ValueError("Model is required for llm type agents")
return v return v
@validator('api_key') @validator("api_key")
def validate_api_key(cls, v, values): def validate_api_key(cls, v, values):
if 'type' in values and values['type'] == 'llm' and not v: if "type" in values and values["type"] == "llm" and not v:
raise ValueError('API Key is required for llm type agents') raise ValueError("API Key is required for llm type agents")
return v return v
@validator('config') @validator("config")
def validate_config(cls, v, values): def validate_config(cls, v, values):
if 'type' not in values: if "type" not in values:
return v return v
if values['type'] == 'llm': if values["type"] == "llm":
if isinstance(v, dict): if isinstance(v, dict):
try: try:
# Convert the dictionary to LLMConfig # Convert the dictionary to LLMConfig
v = LLMConfig(**v) v = LLMConfig(**v)
except Exception as e: except Exception as e:
raise ValueError(f'Invalid LLM configuration for agent: {str(e)}') raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
elif not isinstance(v, LLMConfig): elif not isinstance(v, LLMConfig):
raise ValueError('Invalid LLM configuration for agent') raise ValueError("Invalid LLM configuration for agent")
elif values['type'] in ['sequential', 'parallel', 'loop']: elif values["type"] in ["sequential", "parallel", "loop"]:
if not isinstance(v, dict): if not isinstance(v, dict):
raise ValueError(f'Invalid configuration for agent {values["type"]}') raise ValueError(f'Invalid configuration for agent {values["type"]}')
if 'sub_agents' not in v: if "sub_agents" not in v:
raise ValueError(f'Agent {values["type"]} must have sub_agents') raise ValueError(f'Agent {values["type"]} must have sub_agents')
if not isinstance(v['sub_agents'], list): if not isinstance(v["sub_agents"], list):
raise ValueError('sub_agents must be a list') raise ValueError("sub_agents must be a list")
if not v['sub_agents']: if not v["sub_agents"]:
raise ValueError(f'Agent {values["type"]} must have at least one sub-agent') raise ValueError(
f'Agent {values["type"]} must have at least one sub-agent'
)
return v return v
class AgentCreate(AgentBase): class AgentCreate(AgentBase):
client_id: UUID client_id: UUID
class Agent(AgentBase): class Agent(AgentBase):
id: UUID id: UUID
client_id: UUID client_id: UUID
@ -105,6 +124,7 @@ class Agent(AgentBase):
class Config: class Config:
from_attributes = True from_attributes = True
class MCPServerBase(BaseModel): class MCPServerBase(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
@ -113,9 +133,11 @@ class MCPServerBase(BaseModel):
tools: List[str] = Field(default_factory=list) tools: List[str] = Field(default_factory=list)
type: str = Field(default="official") type: str = Field(default="official")
class MCPServerCreate(MCPServerBase): class MCPServerCreate(MCPServerBase):
pass pass
class MCPServer(MCPServerBase): class MCPServer(MCPServerBase):
id: uuid.UUID id: uuid.UUID
created_at: datetime created_at: datetime
@ -124,19 +146,22 @@ class MCPServer(MCPServerBase):
class Config: class Config:
from_attributes = True from_attributes = True
class ToolBase(BaseModel): class ToolBase(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
config_json: Dict[str, Any] = Field(default_factory=dict) config_json: Dict[str, Any] = Field(default_factory=dict)
environments: Dict[str, Any] = Field(default_factory=dict) environments: Dict[str, Any] = Field(default_factory=dict)
class ToolCreate(ToolBase): class ToolCreate(ToolBase):
pass pass
class Tool(ToolBase): class Tool(ToolBase):
id: uuid.UUID id: uuid.UUID
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
class Config: class Config:
from_attributes = True from_attributes = True

View File

@ -1,23 +1,28 @@
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
class UserBase(BaseModel): class UserBase(BaseModel):
email: EmailStr email: EmailStr
class UserCreate(UserBase): class UserCreate(UserBase):
password: str password: str
name: str # For client creation name: str # For client creation
class AdminUserCreate(UserBase): class AdminUserCreate(UserBase):
password: str password: str
name: str name: str
class UserLogin(BaseModel): class UserLogin(BaseModel):
email: EmailStr email: EmailStr
password: str password: str
class UserResponse(UserBase): class UserResponse(UserBase):
id: UUID id: UUID
client_id: Optional[UUID] = None client_id: Optional[UUID] = None
@ -25,26 +30,31 @@ class UserResponse(UserBase):
email_verified: bool email_verified: bool
is_admin: bool is_admin: bool
created_at: datetime created_at: datetime
class Config: class Config:
from_attributes = True from_attributes = True
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class TokenData(BaseModel): class TokenData(BaseModel):
sub: str # user email sub: str # user email
exp: datetime exp: datetime
is_admin: bool is_admin: bool
client_id: Optional[UUID] = None client_id: Optional[UUID] = None
class PasswordReset(BaseModel): class PasswordReset(BaseModel):
token: str token: str
new_password: str new_password: str
class ForgotPassword(BaseModel): class ForgotPassword(BaseModel):
email: EmailStr email: EmailStr
class MessageResponse(BaseModel): class MessageResponse(BaseModel):
message: str message: str

View File

@ -1 +1 @@
from .agent_runner import run_agent from .agent_runner import run_agent

View File

@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest from google.adk.models import LlmResponse, LlmRequest
from google.adk.tools import load_memory from google.adk.tools import load_memory
from typing import Optional
import logging
import os
import requests import requests
import os
from datetime import datetime from datetime import datetime
logger = setup_logger(__name__) logger = setup_logger(__name__)
@ -83,7 +82,7 @@ def before_model_callback(
llm_request.config.system_instruction = modified_text llm_request.config.system_instruction = modified_text
logger.debug( logger.debug(
f"📝 System instruction updated with search results and history" "📝 System instruction updated with search results and history"
) )
else: else:
logger.warning("⚠️ No results found in the search") logger.warning("⚠️ No results found in the search")
@ -180,11 +179,13 @@ class AgentBuilder:
mcp_tools = [] mcp_tools = []
mcp_exit_stack = None mcp_exit_stack = None
if agent.config.get("mcp_servers"): if agent.config.get("mcp_servers"):
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db) mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
agent.config, self.db
)
# Combine all tools # Combine all tools
all_tools = custom_tools + mcp_tools all_tools = custom_tools + mcp_tools
now = datetime.now() now = datetime.now()
current_datetime = now.strftime("%d/%m/%Y %H:%M") current_datetime = now.strftime("%d/%m/%Y %H:%M")
current_day_of_week = now.strftime("%A") current_day_of_week = now.strftime("%A")
@ -201,10 +202,13 @@ class AgentBuilder:
# Check if load_memory is enabled # Check if load_memory is enabled
# before_model_callback_func = None # before_model_callback_func = None
if agent.config.get("load_memory") == True: if agent.config.get("load_memory"):
all_tools.append(load_memory) all_tools.append(load_memory)
# before_model_callback_func = before_model_callback # before_model_callback_func = before_model_callback
formatted_prompt = formatted_prompt + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n" formatted_prompt = (
formatted_prompt
+ "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
)
return ( return (
LlmAgent( LlmAgent(

View File

@ -22,9 +22,7 @@ async def run_agent(
db: Session, db: Session,
): ):
try: try:
logger.info( logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
f"Starting execution of agent {agent_id} for contact {contact_id}"
)
logger.info(f"Received message: {message}") logger.info(f"Received message: {message}")
get_root_agent = get_agent(db, agent_id) get_root_agent = get_agent(db, agent_id)
@ -77,15 +75,15 @@ async def run_agent(
if event.is_final_response() and event.content and event.content.parts: if event.is_final_response() and event.content and event.content.parts:
final_response_text = event.content.parts[0].text final_response_text = event.content.parts[0].text
logger.info(f"Final response received: {final_response_text}") logger.info(f"Final response received: {final_response_text}")
completed_session = session_service.get_session( completed_session = session_service.get_session(
app_name=agent_id, app_name=agent_id,
user_id=contact_id, user_id=contact_id,
session_id=session_id, session_id=session_id,
) )
memory_service.add_session_to_memory(completed_session) memory_service.add_session_to_memory(completed_session)
finally: finally:
# Ensure the exit_stack is closed correctly # Ensure the exit_stack is closed correctly
if exit_stack: if exit_stack:

View File

@ -216,9 +216,7 @@ async def update_agent(
return agent return agent
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")
status_code=500, detail=f"Error updating agent: {str(e)}"
)
def delete_agent(db: Session, agent_id: uuid.UUID) -> bool: def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:

View File

@ -1,15 +1,15 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from src.models.models import AuditLog, User from src.models.models import AuditLog
from datetime import datetime from datetime import datetime
from fastapi import Request from fastapi import Request
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
import uuid import uuid
import logging import logging
import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_audit_log( def create_audit_log(
db: Session, db: Session,
user_id: Optional[uuid.UUID], user_id: Optional[uuid.UUID],
@ -17,11 +17,11 @@ def create_audit_log(
resource_type: str, resource_type: str,
resource_id: Optional[str] = None, resource_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None, details: Optional[Dict[str, Any]] = None,
request: Optional[Request] = None request: Optional[Request] = None,
) -> Optional[AuditLog]: ) -> Optional[AuditLog]:
""" """
Create a new audit log Create a new audit log
Args: Args:
db: Database session db: Database session
user_id: User ID that performed the action (or None if anonymous) user_id: User ID that performed the action (or None if anonymous)
@ -30,25 +30,25 @@ def create_audit_log(
resource_id: Resource ID (optional) resource_id: Resource ID (optional)
details: Additional details of the action (optional) details: Additional details of the action (optional)
request: FastAPI Request object (optional, to get IP and User-Agent) request: FastAPI Request object (optional, to get IP and User-Agent)
Returns: Returns:
Optional[AuditLog]: Created audit log or None in case of error Optional[AuditLog]: Created audit log or None in case of error
""" """
try: try:
ip_address = None ip_address = None
user_agent = None user_agent = None
if request: if request:
ip_address = request.client.host if hasattr(request, 'client') else None ip_address = request.client.host if hasattr(request, "client") else None
user_agent = request.headers.get("user-agent") user_agent = request.headers.get("user-agent")
# Convert details to serializable format # Convert details to serializable format
if details: if details:
# Convert UUIDs to strings # Convert UUIDs to strings
for key, value in details.items(): for key, value in details.items():
if isinstance(value, uuid.UUID): if isinstance(value, uuid.UUID):
details[key] = str(value) details[key] = str(value)
audit_log = AuditLog( audit_log = AuditLog(
user_id=user_id, user_id=user_id,
action=action, action=action,
@ -56,20 +56,20 @@ def create_audit_log(
resource_id=str(resource_id) if resource_id else None, resource_id=str(resource_id) if resource_id else None,
details=details, details=details,
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent user_agent=user_agent,
) )
db.add(audit_log) db.add(audit_log)
db.commit() db.commit()
db.refresh(audit_log) db.refresh(audit_log)
logger.info( logger.info(
f"Audit log created: {action} in {resource_type}" + f"Audit log created: {action} in {resource_type}"
(f" (ID: {resource_id})" if resource_id else "") + (f" (ID: {resource_id})" if resource_id else "")
) )
return audit_log return audit_log
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error creating audit log: {str(e)}") logger.error(f"Error creating audit log: {str(e)}")
@ -78,6 +78,7 @@ def create_audit_log(
logger.error(f"Unexpected error creating audit log: {str(e)}") logger.error(f"Unexpected error creating audit log: {str(e)}")
return None return None
def get_audit_logs( def get_audit_logs(
db: Session, db: Session,
skip: int = 0, skip: int = 0,
@ -87,11 +88,11 @@ def get_audit_logs(
resource_type: Optional[str] = None, resource_type: Optional[str] = None,
resource_id: Optional[str] = None, resource_id: Optional[str] = None,
start_date: Optional[datetime] = None, start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None end_date: Optional[datetime] = None,
) -> List[AuditLog]: ) -> List[AuditLog]:
""" """
Get audit logs with optional filters Get audit logs with optional filters
Args: Args:
db: Database session db: Database session
skip: Number of records to skip skip: Number of records to skip
@ -102,35 +103,35 @@ def get_audit_logs(
resource_id: Filter by resource ID resource_id: Filter by resource ID
start_date: Start date start_date: Start date
end_date: End date end_date: End date
Returns: Returns:
List[AuditLog]: List of audit logs List[AuditLog]: List of audit logs
""" """
query = db.query(AuditLog) query = db.query(AuditLog)
# Apply filters, if provided # Apply filters, if provided
if user_id: if user_id:
query = query.filter(AuditLog.user_id == user_id) query = query.filter(AuditLog.user_id == user_id)
if action: if action:
query = query.filter(AuditLog.action == action) query = query.filter(AuditLog.action == action)
if resource_type: if resource_type:
query = query.filter(AuditLog.resource_type == resource_type) query = query.filter(AuditLog.resource_type == resource_type)
if resource_id: if resource_id:
query = query.filter(AuditLog.resource_id == resource_id) query = query.filter(AuditLog.resource_id == resource_id)
if start_date: if start_date:
query = query.filter(AuditLog.created_at >= start_date) query = query.filter(AuditLog.created_at >= start_date)
if end_date: if end_date:
query = query.filter(AuditLog.created_at <= end_date) query = query.filter(AuditLog.created_at <= end_date)
# Order by creation date (most recent first) # Order by creation date (most recent first)
query = query.order_by(AuditLog.created_at.desc()) query = query.order_by(AuditLog.created_at.desc())
# Apply pagination # Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
return query.all() return query.all()

View File

@ -16,17 +16,20 @@ logger = logging.getLogger(__name__)
# Define OAuth2 authentication scheme with password flow # Define OAuth2 authentication scheme with password flow
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
async def get_current_user(
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> User:
""" """
Get the current user from the JWT token Get the current user from the JWT token
Args: Args:
token: JWT token token: JWT token
db: Database session db: Database session
Returns: Returns:
User: Current user User: Current user
Raises: Raises:
HTTPException: If the token is invalid or the user is not found HTTPException: If the token is invalid or the user is not found
""" """
@ -35,103 +38,108 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
detail="Invalid credentials", detail="Invalid credentials",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
# Decode the token # Decode the token
payload = jwt.decode( payload = jwt.decode(
token, token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
) )
# Extract token data # Extract token data
email: str = payload.get("sub") email: str = payload.get("sub")
if email is None: if email is None:
logger.warning("Token without email (sub)") logger.warning("Token without email (sub)")
raise credentials_exception raise credentials_exception
# Check if the token has expired # Check if the token has expired
exp = payload.get("exp") exp = payload.get("exp")
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow(): if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
logger.warning(f"Token expired for {email}") logger.warning(f"Token expired for {email}")
raise credentials_exception raise credentials_exception
# Create TokenData object # Create TokenData object
token_data = TokenData( token_data = TokenData(
sub=email, sub=email,
exp=datetime.fromtimestamp(exp), exp=datetime.fromtimestamp(exp),
is_admin=payload.get("is_admin", False), is_admin=payload.get("is_admin", False),
client_id=payload.get("client_id") client_id=payload.get("client_id"),
) )
except JWTError as e: except JWTError as e:
logger.error(f"Error decoding JWT token: {str(e)}") logger.error(f"Error decoding JWT token: {str(e)}")
raise credentials_exception raise credentials_exception
# Search for user in the database # Search for user in the database
user = get_user_by_email(db, email=token_data.sub) user = get_user_by_email(db, email=token_data.sub)
if user is None: if user is None:
logger.warning(f"User not found for email: {token_data.sub}") logger.warning(f"User not found for email: {token_data.sub}")
raise credentials_exception raise credentials_exception
if not user.is_active: if not user.is_active:
logger.warning(f"Attempt to access inactive user: {user.email}") logger.warning(f"Attempt to access inactive user: {user.email}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
detail="Inactive user"
) )
return user return user
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
""" """
Check if the current user is active Check if the current user is active
Args: Args:
current_user: Current user current_user: Current user
Returns: Returns:
User: Current user if active User: Current user if active
Raises: Raises:
HTTPException: If the user is not active HTTPException: If the user is not active
""" """
if not current_user.is_active: if not current_user.is_active:
logger.warning(f"Attempt to access inactive user: {current_user.email}") logger.warning(f"Attempt to access inactive user: {current_user.email}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
detail="Inactive user"
) )
return current_user return current_user
async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
async def get_current_admin_user(
current_user: User = Depends(get_current_user),
) -> User:
""" """
Check if the current user is an administrator Check if the current user is an administrator
Args: Args:
current_user: Current user current_user: Current user
Returns: Returns:
User: Current user if administrator User: Current user if administrator
Raises: Raises:
HTTPException: If the user is not an administrator HTTPException: If the user is not an administrator
""" """
if not current_user.is_admin: if not current_user.is_admin:
logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}") logger.warning(
f"Attempt to access admin by non-admin user: {current_user.email}"
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. Restricted to administrators." detail="Access denied. Restricted to administrators.",
) )
return current_user return current_user
def create_access_token(user: User) -> str: def create_access_token(user: User) -> str:
""" """
Create a JWT access token for the user Create a JWT access token for the user
Args: Args:
user: User for which to create the token user: User for which to create the token
Returns: Returns:
str: JWT token str: JWT token
""" """
@ -140,10 +148,10 @@ def create_access_token(user: User) -> str:
"sub": user.email, "sub": user.email,
"is_admin": user.is_admin, "is_admin": user.is_admin,
} }
# Include client_id only if not administrator and client_id is set # Include client_id only if not administrator and client_id is set
if not user.is_admin and user.client_id: if not user.is_admin and user.client_id:
token_data["client_id"] = str(user.client_id) token_data["client_id"] = str(user.client_id)
# Create token # Create token
return create_jwt_token(token_data) return create_jwt_token(token_data)

View File

@ -11,6 +11,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]: def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
"""Search for a client by ID""" """Search for a client by ID"""
try: try:
@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
logger.error(f"Error searching for client {client_id}: {str(e)}") logger.error(f"Error searching for client {client_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for client" detail="Error searching for client",
) )
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]: def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
"""Search for all clients with pagination""" """Search for all clients with pagination"""
try: try:
@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
logger.error(f"Error searching for clients: {str(e)}") logger.error(f"Error searching for clients: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for clients" detail="Error searching for clients",
) )
def create_client(db: Session, client: ClientCreate) -> Client: def create_client(db: Session, client: ClientCreate) -> Client:
"""Create a new client""" """Create a new client"""
try: try:
@ -51,19 +54,22 @@ def create_client(db: Session, client: ClientCreate) -> Client:
logger.error(f"Error creating client: {str(e)}") logger.error(f"Error creating client: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating client" detail="Error creating client",
) )
def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]:
def update_client(
db: Session, client_id: uuid.UUID, client: ClientCreate
) -> Optional[Client]:
"""Updates an existing client""" """Updates an existing client"""
try: try:
db_client = get_client(db, client_id) db_client = get_client(db, client_id)
if not db_client: if not db_client:
return None return None
for key, value in client.model_dump().items(): for key, value in client.model_dump().items():
setattr(db_client, key, value) setattr(db_client, key, value)
db.commit() db.commit()
db.refresh(db_client) db.refresh(db_client)
logger.info(f"Client updated successfully: {client_id}") logger.info(f"Client updated successfully: {client_id}")
@ -73,16 +79,17 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op
logger.error(f"Error updating client {client_id}: {str(e)}") logger.error(f"Error updating client {client_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating client" detail="Error updating client",
) )
def delete_client(db: Session, client_id: uuid.UUID) -> bool: def delete_client(db: Session, client_id: uuid.UUID) -> bool:
"""Removes a client""" """Removes a client"""
try: try:
db_client = get_client(db, client_id) db_client = get_client(db, client_id)
if not db_client: if not db_client:
return False return False
db.delete(db_client) db.delete(db_client)
db.commit() db.commit()
logger.info(f"Client removed successfully: {client_id}") logger.info(f"Client removed successfully: {client_id}")
@ -92,18 +99,21 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:
logger.error(f"Error removing client {client_id}: {str(e)}") logger.error(f"Error removing client {client_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing client" detail="Error removing client",
) )
def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]:
def create_client_with_user(
db: Session, client_data: ClientCreate, user_data: UserCreate
) -> Tuple[Optional[Client], str]:
""" """
Creates a new client with an associated user Creates a new client with an associated user
Args: Args:
db: Database session db: Database session
client_data: Client data to be created client_data: Client data to be created
user_data: User data to be created user_data: User data to be created
Returns: Returns:
Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message
""" """
@ -112,27 +122,27 @@ def create_client_with_user(db: Session, client_data: ClientCreate, user_data: U
client = Client(**client_data.model_dump()) client = Client(**client_data.model_dump())
db.add(client) db.add(client)
db.flush() # Get client ID without committing the transaction db.flush() # Get client ID without committing the transaction
# Use client ID to create the associated user # Use client ID to create the associated user
user, message = create_user(db, user_data, is_admin=False, client_id=client.id) user, message = create_user(db, user_data, is_admin=False, client_id=client.id)
if not user: if not user:
# If there was an error creating the user, rollback # If there was an error creating the user, rollback
db.rollback() db.rollback()
logger.error(f"Error creating user for client: {message}") logger.error(f"Error creating user for client: {message}")
return None, f"Error creating user: {message}" return None, f"Error creating user: {message}"
# If everything went well, commit the transaction # If everything went well, commit the transaction
db.commit() db.commit()
logger.info(f"Client and user created successfully: {client.id}") logger.info(f"Client and user created successfully: {client.id}")
return client, "Client and user created successfully" return client, "Client and user created successfully"
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error creating client with user: {str(e)}") logger.error(f"Error creating client with user: {str(e)}")
return None, f"Error creating client with user: {str(e)}" return None, f"Error creating client with user: {str(e)}"
except Exception as e: except Exception as e:
db.rollback() db.rollback()
logger.error(f"Unexpected error creating client with user: {str(e)}") logger.error(f"Unexpected error creating client with user: {str(e)}")
return None, f"Unexpected error: {str(e)}" return None, f"Unexpected error: {str(e)}"

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]: def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
"""Search for a contact by ID""" """Search for a contact by ID"""
try: try:
@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
logger.error(f"Error searching for contact {contact_id}: {str(e)}") logger.error(f"Error searching for contact {contact_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contact" detail="Error searching for contact",
) )
def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]:
def get_contacts_by_client(
db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
) -> List[Contact]:
"""Search for contacts of a client with pagination""" """Search for contacts of a client with pagination"""
try: try:
return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all() return (
db.query(Contact)
.filter(Contact.client_id == client_id)
.offset(skip)
.limit(limit)
.all()
)
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}") logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contacts" detail="Error searching for contacts",
) )
def create_contact(db: Session, contact: ContactCreate) -> Contact: def create_contact(db: Session, contact: ContactCreate) -> Contact:
"""Create a new contact""" """Create a new contact"""
try: try:
@ -49,19 +60,22 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:
logger.error(f"Error creating contact: {str(e)}") logger.error(f"Error creating contact: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating contact" detail="Error creating contact",
) )
def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]:
def update_contact(
db: Session, contact_id: uuid.UUID, contact: ContactCreate
) -> Optional[Contact]:
"""Update an existing contact""" """Update an existing contact"""
try: try:
db_contact = get_contact(db, contact_id) db_contact = get_contact(db, contact_id)
if not db_contact: if not db_contact:
return None return None
for key, value in contact.model_dump().items(): for key, value in contact.model_dump().items():
setattr(db_contact, key, value) setattr(db_contact, key, value)
db.commit() db.commit()
db.refresh(db_contact) db.refresh(db_contact)
logger.info(f"Contact updated successfully: {contact_id}") logger.info(f"Contact updated successfully: {contact_id}")
@ -71,16 +85,17 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -
logger.error(f"Error updating contact {contact_id}: {str(e)}") logger.error(f"Error updating contact {contact_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating contact" detail="Error updating contact",
) )
def delete_contact(db: Session, contact_id: uuid.UUID) -> bool: def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
"""Remove a contact""" """Remove a contact"""
try: try:
db_contact = get_contact(db, contact_id) db_contact = get_contact(db, contact_id)
if not db_contact: if not db_contact:
return False return False
db.delete(db_contact) db.delete(db_contact)
db.commit() db.commit()
logger.info(f"Contact removed successfully: {contact_id}") logger.info(f"Contact removed successfully: {contact_id}")
@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
logger.error(f"Error removing contact {contact_id}: {str(e)}") logger.error(f"Error removing contact {contact_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing contact" detail="Error removing contact",
) )

View File

@ -6,6 +6,7 @@ from src.utils.logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
class CustomToolBuilder: class CustomToolBuilder:
def __init__(self): def __init__(self):
self.tools = [] self.tools = []
@ -53,7 +54,9 @@ class CustomToolBuilder:
# Adds default values to query params if they are not present # Adds default values to query params if they are not present
for param, value in values.items(): for param, value in values.items():
if param not in query_params and param not in parameters.get("path_params", {}): if param not in query_params and param not in parameters.get(
"path_params", {}
):
query_params[param] = value query_params[param] = value
# Processa body parameters # Processa body parameters
@ -64,7 +67,11 @@ class CustomToolBuilder:
# Adds default values to body if they are not present # Adds default values to body if they are not present
for param, value in values.items(): for param, value in values.items():
if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}): if (
param not in body_data
and param not in query_params
and param not in parameters.get("path_params", {})
):
body_data[param] = value body_data[param] = value
# Makes the HTTP request # Makes the HTTP request
@ -74,7 +81,7 @@ class CustomToolBuilder:
headers=processed_headers, headers=processed_headers,
params=query_params, params=query_params,
json=body_data if body_data else None, json=body_data if body_data else None,
timeout=error_handling.get("timeout", 30) timeout=error_handling.get("timeout", 30),
) )
if response.status_code >= 400: if response.status_code >= 400:
@ -87,30 +94,34 @@ class CustomToolBuilder:
except Exception as e: except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}") logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(error_handling.get("fallback_response", { return json.dumps(
"error": "tool_execution_error", error_handling.get(
"message": str(e) "fallback_response",
})) {"error": "tool_execution_error", "message": str(e)},
)
)
# Adds dynamic docstring based on the configuration # Adds dynamic docstring based on the configuration
param_docs = [] param_docs = []
# Adds path parameters # Adds path parameters
for param, value in parameters.get("path_params", {}).items(): for param, value in parameters.get("path_params", {}).items():
param_docs.append(f"{param}: {value}") param_docs.append(f"{param}: {value}")
# Adds query parameters # Adds query parameters
for param, value in parameters.get("query_params", {}).items(): for param, value in parameters.get("query_params", {}).items():
if isinstance(value, list): if isinstance(value, list):
param_docs.append(f"{param}: List[{', '.join(value)}]") param_docs.append(f"{param}: List[{', '.join(value)}]")
else: else:
param_docs.append(f"{param}: {value}") param_docs.append(f"{param}: {value}")
# Adds body parameters # Adds body parameters
for param, param_config in parameters.get("body_params", {}).items(): for param, param_config in parameters.get("body_params", {}).items():
required = "Required" if param_config.get("required", False) else "Optional" required = "Required" if param_config.get("required", False) else "Optional"
param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}") param_docs.append(
f"{param} ({param_config['type']}, {required}): {param_config['description']}"
)
# Adds default values # Adds default values
if values: if values:
param_docs.append("\nDefault values:") param_docs.append("\nDefault values:")
@ -119,10 +130,10 @@ class CustomToolBuilder:
http_tool.__doc__ = f""" http_tool.__doc__ = f"""
{description} {description}
Parameters: Parameters:
{chr(10).join(param_docs)} {chr(10).join(param_docs)}
Returns: Returns:
String containing the response in JSON format String containing the response in JSON format
""" """
@ -140,4 +151,4 @@ class CustomToolBuilder:
for http_tool_config in tools_config.get("http_tools", []): for http_tool_config in tools_config.get("http_tools", []):
self.tools.append(self._create_http_tool(http_tool_config)) self.tools.append(self._create_http_tool(http_tool_config))
return self.tools return self.tools

View File

@ -16,17 +16,18 @@ os.makedirs(templates_dir, exist_ok=True)
# Configure Jinja2 with the templates directory # Configure Jinja2 with the templates directory
env = Environment( env = Environment(
loader=FileSystemLoader(templates_dir), loader=FileSystemLoader(templates_dir),
autoescape=select_autoescape(['html', 'xml']) autoescape=select_autoescape(["html", "xml"]),
) )
def _render_template(template_name: str, context: dict) -> str: def _render_template(template_name: str, context: dict) -> str:
""" """
Render a template with the provided data Render a template with the provided data
Args: Args:
template_name: Template file name template_name: Template file name
context: Data to render in the template context: Data to render in the template
Returns: Returns:
str: Rendered HTML str: Rendered HTML
""" """
@ -37,14 +38,15 @@ def _render_template(template_name: str, context: dict) -> str:
logger.error(f"Error rendering template '{template_name}': {str(e)}") logger.error(f"Error rendering template '{template_name}': {str(e)}")
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>" return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"
def send_verification_email(email: str, token: str) -> bool: def send_verification_email(email: str, token: str) -> bool:
""" """
Send a verification email to the user Send a verification email to the user
Args: Args:
email: Recipient's email email: Recipient's email
token: Email verification token token: Email verification token
Returns: Returns:
bool: True if the email was sent successfully, False otherwise bool: True if the email was sent successfully, False otherwise
""" """
@ -53,39 +55,47 @@ def send_verification_email(email: str, token: str) -> bool:
from_email = Email(settings.EMAIL_FROM) from_email = Email(settings.EMAIL_FROM)
to_email = To(email) to_email = To(email)
subject = "Email Verification - Evo AI" subject = "Email Verification - Evo AI"
verification_link = f"{settings.APP_URL}/auth/verify-email/{token}" verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"
html_content = _render_template('verification_email', { html_content = _render_template(
'verification_link': verification_link, "verification_email",
'user_name': email.split('@')[0], # Use part of the email as temporary name {
'current_year': datetime.now().year "verification_link": verification_link,
}) "user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content) content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content) mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get()) response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300: if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Verification email sent to {email}") logger.info(f"Verification email sent to {email}")
return True return True
else: else:
logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}") logger.error(
f"Failed to send verification email to {email}. Status: {response.status_code}"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error sending verification email to {email}: {str(e)}") logger.error(f"Error sending verification email to {email}: {str(e)}")
return False return False
def send_password_reset_email(email: str, token: str) -> bool: def send_password_reset_email(email: str, token: str) -> bool:
""" """
Send a password reset email to the user Send a password reset email to the user
Args: Args:
email: Recipient's email email: Recipient's email
token: Password reset token token: Password reset token
Returns: Returns:
bool: True if the email was sent successfully, False otherwise bool: True if the email was sent successfully, False otherwise
""" """
@ -94,39 +104,47 @@ def send_password_reset_email(email: str, token: str) -> bool:
from_email = Email(settings.EMAIL_FROM) from_email = Email(settings.EMAIL_FROM)
to_email = To(email) to_email = To(email)
subject = "Password Reset - Evo AI" subject = "Password Reset - Evo AI"
reset_link = f"{settings.APP_URL}/reset-password?token={token}" reset_link = f"{settings.APP_URL}/reset-password?token={token}"
html_content = _render_template('password_reset', { html_content = _render_template(
'reset_link': reset_link, "password_reset",
'user_name': email.split('@')[0], # Use part of the email as temporary name {
'current_year': datetime.now().year "reset_link": reset_link,
}) "user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content) content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content) mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get()) response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300: if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Password reset email sent to {email}") logger.info(f"Password reset email sent to {email}")
return True return True
else: else:
logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}") logger.error(
f"Failed to send password reset email to {email}. Status: {response.status_code}"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error sending password reset email to {email}: {str(e)}") logger.error(f"Error sending password reset email to {email}: {str(e)}")
return False return False
def send_welcome_email(email: str, user_name: str = None) -> bool: def send_welcome_email(email: str, user_name: str = None) -> bool:
""" """
Send a welcome email to the user after verification Send a welcome email to the user after verification
Args: Args:
email: Recipient's email email: Recipient's email
user_name: User's name (optional) user_name: User's name (optional)
Returns: Returns:
bool: True if the email was sent successfully, False otherwise bool: True if the email was sent successfully, False otherwise
""" """
@ -135,41 +153,49 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
from_email = Email(settings.EMAIL_FROM) from_email = Email(settings.EMAIL_FROM)
to_email = To(email) to_email = To(email)
subject = "Welcome to Evo AI" subject = "Welcome to Evo AI"
dashboard_link = f"{settings.APP_URL}/dashboard" dashboard_link = f"{settings.APP_URL}/dashboard"
html_content = _render_template('welcome_email', { html_content = _render_template(
'dashboard_link': dashboard_link, "welcome_email",
'user_name': user_name or email.split('@')[0], {
'current_year': datetime.now().year "dashboard_link": dashboard_link,
}) "user_name": user_name or email.split("@")[0],
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content) content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content) mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get()) response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300: if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Welcome email sent to {email}") logger.info(f"Welcome email sent to {email}")
return True return True
else: else:
logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}") logger.error(
f"Failed to send welcome email to {email}. Status: {response.status_code}"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error sending welcome email to {email}: {str(e)}") logger.error(f"Error sending welcome email to {email}: {str(e)}")
return False return False
def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool:
def send_account_locked_email(
email: str, reset_token: str, failed_attempts: int, time_period: str
) -> bool:
""" """
Send an email informing that the account has been locked after login attempts Send an email informing that the account has been locked after login attempts
Args: Args:
email: Recipient's email email: Recipient's email
reset_token: Token to reset the password reset_token: Token to reset the password
failed_attempts: Number of failed attempts failed_attempts: Number of failed attempts
time_period: Time period of the attempts time_period: Time period of the attempts
Returns: Returns:
bool: True if the email was sent successfully, False otherwise bool: True if the email was sent successfully, False otherwise
""" """
@ -178,29 +204,34 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
from_email = Email(settings.EMAIL_FROM) from_email = Email(settings.EMAIL_FROM)
to_email = To(email) to_email = To(email)
subject = "Security Alert - Account Locked" subject = "Security Alert - Account Locked"
reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}" reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"
html_content = _render_template('account_locked', { html_content = _render_template(
'reset_link': reset_link, "account_locked",
'user_name': email.split('@')[0], {
'failed_attempts': failed_attempts, "reset_link": reset_link,
'time_period': time_period, "user_name": email.split("@")[0],
'current_year': datetime.now().year "failed_attempts": failed_attempts,
}) "time_period": time_period,
"current_year": datetime.now().year,
},
)
content = Content("text/html", html_content) content = Content("text/html", html_content)
mail = Mail(from_email, to_email, subject, content) mail = Mail(from_email, to_email, subject, content)
response = sg.client.mail.send.post(request_body=mail.get()) response = sg.client.mail.send.post(request_body=mail.get())
if response.status_code >= 200 and response.status_code < 300: if response.status_code >= 200 and response.status_code < 300:
logger.info(f"Account locked email sent to {email}") logger.info(f"Account locked email sent to {email}")
return True return True
else: else:
logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}") logger.error(
f"Failed to send account locked email to {email}. Status: {response.status_code}"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error sending account locked email to {email}: {str(e)}") logger.error(f"Error sending account locked email to {email}: {str(e)}")
return False return False

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]: def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
"""Search for an MCP server by ID""" """Search for an MCP server by ID"""
try: try:
@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
logger.error(f"Error searching for MCP server {server_id}: {str(e)}") logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP server" detail="Error searching for MCP server",
) )
def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]: def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
"""Search for all MCP servers with pagination""" """Search for all MCP servers with pagination"""
try: try:
@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer
logger.error(f"Error searching for MCP servers: {str(e)}") logger.error(f"Error searching for MCP servers: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP servers" detail="Error searching for MCP servers",
) )
def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer: def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
"""Create a new MCP server""" """Create a new MCP server"""
try: try:
@ -49,19 +52,22 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
logger.error(f"Error creating MCP server: {str(e)}") logger.error(f"Error creating MCP server: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating MCP server" detail="Error creating MCP server",
) )
def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]:
def update_mcp_server(
db: Session, server_id: uuid.UUID, server: MCPServerCreate
) -> Optional[MCPServer]:
"""Update an existing MCP server""" """Update an existing MCP server"""
try: try:
db_server = get_mcp_server(db, server_id) db_server = get_mcp_server(db, server_id)
if not db_server: if not db_server:
return None return None
for key, value in server.model_dump().items(): for key, value in server.model_dump().items():
setattr(db_server, key, value) setattr(db_server, key, value)
db.commit() db.commit()
db.refresh(db_server) db.refresh(db_server)
logger.info(f"MCP server updated successfully: {server_id}") logger.info(f"MCP server updated successfully: {server_id}")
@ -71,16 +77,17 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate
logger.error(f"Error updating MCP server {server_id}: {str(e)}") logger.error(f"Error updating MCP server {server_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating MCP server" detail="Error updating MCP server",
) )
def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool: def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
"""Remove an MCP server""" """Remove an MCP server"""
try: try:
db_server = get_mcp_server(db, server_id) db_server = get_mcp_server(db, server_id)
if not db_server: if not db_server:
return False return False
db.delete(db_server) db.delete(db_server)
db.commit() db.commit()
logger.info(f"MCP server removed successfully: {server_id}") logger.info(f"MCP server removed successfully: {server_id}")
@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
logger.error(f"Error removing MCP server {server_id}: {str(e)}") logger.error(f"Error removing MCP server {server_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing MCP server" detail="Error removing MCP server",
) )

View File

@ -12,26 +12,28 @@ from sqlalchemy.orm import Session
logger = setup_logger(__name__) logger = setup_logger(__name__)
class MCPService: class MCPService:
def __init__(self): def __init__(self):
self.tools = [] self.tools = []
self.exit_stack = AsyncExitStack() self.exit_stack = AsyncExitStack()
async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]: async def _connect_to_mcp_server(
self, server_config: Dict[str, Any]
) -> Tuple[List[Any], Optional[AsyncExitStack]]:
"""Connect to a specific MCP server and return its tools.""" """Connect to a specific MCP server and return its tools."""
try: try:
# Determines the type of server (local or remote) # Determines the type of server (local or remote)
if "url" in server_config: if "url" in server_config:
# Remote server (SSE) # Remote server (SSE)
connection_params = SseServerParams( connection_params = SseServerParams(
url=server_config["url"], url=server_config["url"], headers=server_config.get("headers", {})
headers=server_config.get("headers", {})
) )
else: else:
# Local server (Stdio) # Local server (Stdio)
command = server_config.get("command", "npx") command = server_config.get("command", "npx")
args = server_config.get("args", []) args = server_config.get("args", [])
# Adds environment variables if specified # Adds environment variables if specified
env = server_config.get("env", {}) env = server_config.get("env", {})
if env: if env:
@ -39,9 +41,7 @@ class MCPService:
os.environ[key] = value os.environ[key] = value
connection_params = StdioServerParameters( connection_params = StdioServerParameters(
command=command, command=command, args=args, env=env
args=args,
env=env
) )
tools, exit_stack = await MCPToolset.from_server( tools, exit_stack = await MCPToolset.from_server(
@ -73,8 +73,10 @@ class MCPService:
logger.warning(f"Removed {removed_count} incompatible tools.") logger.warning(f"Removed {removed_count} incompatible tools.")
return filtered_tools return filtered_tools
def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]: def _filter_tools_by_agent(
self, tools: List[Any], agent_tools: List[str]
) -> List[Any]:
"""Filters tools compatible with the agent.""" """Filters tools compatible with the agent."""
filtered_tools = [] filtered_tools = []
for tool in tools: for tool in tools:
@ -83,7 +85,9 @@ class MCPService:
filtered_tools.append(tool) filtered_tools.append(tool)
return filtered_tools return filtered_tools
async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]: async def build_tools(
self, mcp_config: Dict[str, Any], db: Session
) -> Tuple[List[Any], AsyncExitStack]:
"""Builds a list of tools from multiple MCP servers.""" """Builds a list of tools from multiple MCP servers."""
self.tools = [] self.tools = []
self.exit_stack = AsyncExitStack() self.exit_stack = AsyncExitStack()
@ -92,23 +96,25 @@ class MCPService:
for server in mcp_config.get("mcp_servers", []): for server in mcp_config.get("mcp_servers", []):
try: try:
# Search for the MCP server in the database # Search for the MCP server in the database
mcp_server = get_mcp_server(db, server['id']) mcp_server = get_mcp_server(db, server["id"])
if not mcp_server: if not mcp_server:
logger.warning(f"Servidor MCP não encontrado: {server['id']}") logger.warning(f"Servidor MCP não encontrado: {server['id']}")
continue continue
# Prepares the server configuration # Prepares the server configuration
server_config = mcp_server.config_json.copy() server_config = mcp_server.config_json.copy()
# Replaces the environment variables in the config_json # Replaces the environment variables in the config_json
if 'env' in server_config: if "env" in server_config:
for key, value in server_config['env'].items(): for key, value in server_config["env"].items():
if value.startswith('env@@'): if value.startswith("env@@"):
env_key = value.replace('env@@', '') env_key = value.replace("env@@", "")
if env_key in server.get('envs', {}): if env_key in server.get("envs", {}):
server_config['env'][key] = server['envs'][env_key] server_config["env"][key] = server["envs"][env_key]
else: else:
logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}") logger.warning(
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
)
continue continue
logger.info(f"Connecting to MCP server: {mcp_server.name}") logger.info(f"Connecting to MCP server: {mcp_server.name}")
@ -117,22 +123,30 @@ class MCPService:
if tools and exit_stack: if tools and exit_stack:
# Filters incompatible tools # Filters incompatible tools
filtered_tools = self._filter_incompatible_tools(tools) filtered_tools = self._filter_incompatible_tools(tools)
# Filters tools compatible with the agent # Filters tools compatible with the agent
agent_tools = server.get('tools', []) agent_tools = server.get("tools", [])
filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools) filtered_tools = self._filter_tools_by_agent(
filtered_tools, agent_tools
)
self.tools.extend(filtered_tools) self.tools.extend(filtered_tools)
# Registers the exit_stack with the AsyncExitStack # Registers the exit_stack with the AsyncExitStack
await self.exit_stack.enter_async_context(exit_stack) await self.exit_stack.enter_async_context(exit_stack)
logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.") logger.info(
f"Connected successfully. Added {len(filtered_tools)} tools."
)
else: else:
logger.warning(f"Failed to connect or no tools available for {mcp_server.name}") logger.warning(
f"Failed to connect or no tools available for {mcp_server.name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error connecting to MCP server {server['id']}: {e}") logger.error(f"Error connecting to MCP server {server['id']}: {e}")
continue continue
logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.") logger.info(
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
)
return self.tools, self.exit_stack return self.tools, self.exit_stack

View File

@ -0,0 +1,9 @@
from src.config.settings import settings
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()

View File

@ -66,7 +66,7 @@ def get_session_by_id(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid session ID. Expected format: app_name_user_id", detail="Invalid session ID. Expected format: app_name_user_id",
) )
parts = session_id.split("_", 1) parts = session_id.split("_", 1)
if len(parts) != 2: if len(parts) != 2:
logger.error(f"Invalid session ID format: {session_id}") logger.error(f"Invalid session ID format: {session_id}")
@ -74,22 +74,22 @@ def get_session_by_id(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid session ID format. Expected format: app_name_user_id", detail="Invalid session ID format. Expected format: app_name_user_id",
) )
user_id, app_name = parts user_id, app_name = parts
session = session_service.get_session( session = session_service.get_session(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
) )
if session is None: if session is None:
logger.error(f"Session not found: {session_id}") logger.error(f"Session not found: {session_id}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session not found: {session_id}", detail=f"Session not found: {session_id}",
) )
return session return session
except Exception as e: except Exception as e:
logger.error(f"Error searching for session {session_id}: {str(e)}") logger.error(f"Error searching for session {session_id}: {str(e)}")
@ -106,7 +106,7 @@ def delete_session(session_service: DatabaseSessionService, session_id: str) ->
try: try:
session = get_session_by_id(session_service, session_id) session = get_session_by_id(session_service, session_id)
# If we get here, the session exists (get_session_by_id already validates) # If we get here, the session exists (get_session_by_id already validates)
session_service.delete_session( session_service.delete_session(
app_name=session.app_name, app_name=session.app_name,
user_id=session.user_id, user_id=session.user_id,
@ -131,10 +131,10 @@ def get_session_events(
try: try:
session = get_session_by_id(session_service, session_id) session = get_session_by_id(session_service, session_id)
# If we get here, the session exists (get_session_by_id already validates) # If we get here, the session exists (get_session_by_id already validates)
if not hasattr(session, 'events') or session.events is None: if not hasattr(session, "events") or session.events is None:
return [] return []
return session.events return session.events
except HTTPException: except HTTPException:
# Passes HTTP exceptions from get_session_by_id # Passes HTTP exceptions from get_session_by_id

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]: def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
"""Search for a tool by ID""" """Search for a tool by ID"""
try: try:
@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
logger.error(f"Error searching for tool {tool_id}: {str(e)}") logger.error(f"Error searching for tool {tool_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tool" detail="Error searching for tool",
) )
def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]: def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
"""Search for all tools with pagination""" """Search for all tools with pagination"""
try: try:
@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
logger.error(f"Error searching for tools: {str(e)}") logger.error(f"Error searching for tools: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tools" detail="Error searching for tools",
) )
def create_tool(db: Session, tool: ToolCreate) -> Tool: def create_tool(db: Session, tool: ToolCreate) -> Tool:
"""Creates a new tool""" """Creates a new tool"""
try: try:
@ -49,19 +52,20 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:
logger.error(f"Error creating tool: {str(e)}") logger.error(f"Error creating tool: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating tool" detail="Error creating tool",
) )
def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]: def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
"""Updates an existing tool""" """Updates an existing tool"""
try: try:
db_tool = get_tool(db, tool_id) db_tool = get_tool(db, tool_id)
if not db_tool: if not db_tool:
return None return None
for key, value in tool.model_dump().items(): for key, value in tool.model_dump().items():
setattr(db_tool, key, value) setattr(db_tool, key, value)
db.commit() db.commit()
db.refresh(db_tool) db.refresh(db_tool)
logger.info(f"Tool updated successfully: {tool_id}") logger.info(f"Tool updated successfully: {tool_id}")
@ -71,16 +75,17 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T
logger.error(f"Error updating tool {tool_id}: {str(e)}") logger.error(f"Error updating tool {tool_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating tool" detail="Error updating tool",
) )
def delete_tool(db: Session, tool_id: uuid.UUID) -> bool: def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
"""Remove a tool""" """Remove a tool"""
try: try:
db_tool = get_tool(db, tool_id) db_tool = get_tool(db, tool_id)
if not db_tool: if not db_tool:
return False return False
db.delete(db_tool) db.delete(db_tool)
db.commit() db.commit()
logger.info(f"Tool removed successfully: {tool_id}") logger.info(f"Tool removed successfully: {tool_id}")
@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
logger.error(f"Error removing tool {tool_id}: {str(e)}") logger.error(f"Error removing tool {tool_id}: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing tool" detail="Error removing tool",
) )

View File

@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError
from src.models.models import User, Client from src.models.models import User, Client
from src.schemas.user import UserCreate from src.schemas.user import UserCreate
from src.utils.security import get_password_hash, verify_password, generate_token from src.utils.security import get_password_hash, verify_password, generate_token
from src.services.email_service import send_verification_email, send_password_reset_email from src.services.email_service import (
send_verification_email,
send_password_reset_email,
)
from datetime import datetime, timedelta from datetime import datetime, timedelta
import uuid import uuid
import logging import logging
@ -11,16 +14,22 @@ from typing import Optional, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]:
def create_user(
db: Session,
user_data: UserCreate,
is_admin: bool = False,
client_id: Optional[uuid.UUID] = None,
) -> Tuple[Optional[User], str]:
""" """
Creates a new user in the system Creates a new user in the system
Args: Args:
db: Database session db: Database session
user_data: User data to be created user_data: User data to be created
is_admin: If the user is an administrator is_admin: If the user is an administrator
client_id: Associated client ID (optional, a new one will be created if not provided) client_id: Associated client ID (optional, a new one will be created if not provided)
Returns: Returns:
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
""" """
@ -28,17 +37,19 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
# Check if email already exists # Check if email already exists
db_user = db.query(User).filter(User.email == user_data.email).first() db_user = db.query(User).filter(User.email == user_data.email).first()
if db_user: if db_user:
logger.warning(f"Attempt to register with existing email: {user_data.email}") logger.warning(
f"Attempt to register with existing email: {user_data.email}"
)
return None, "Email already registered" return None, "Email already registered"
# Create verification token # Create verification token
verification_token = generate_token() verification_token = generate_token()
token_expiry = datetime.utcnow() + timedelta(hours=24) token_expiry = datetime.utcnow() + timedelta(hours=24)
# Start transaction # Start transaction
user = None user = None
local_client_id = client_id local_client_id = client_id
try: try:
# If not admin and no client_id, create an associated client # If not admin and no client_id, create an associated client
if not is_admin and local_client_id is None: if not is_admin and local_client_id is None:
@ -46,7 +57,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
db.add(client) db.add(client)
db.flush() # Get the client ID db.flush() # Get the client ID
local_client_id = client.id local_client_id = client.id
# Create user # Create user
user = User( user = User(
email=user_data.email, email=user_data.email,
@ -56,52 +67,56 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
is_active=False, # Inactive until email is verified is_active=False, # Inactive until email is verified
email_verified=False, email_verified=False,
verification_token=verification_token, verification_token=verification_token,
verification_token_expiry=token_expiry verification_token_expiry=token_expiry,
) )
db.add(user) db.add(user)
db.commit() db.commit()
# Send verification email # Send verification email
email_sent = send_verification_email(user.email, verification_token) email_sent = send_verification_email(user.email, verification_token)
if not email_sent: if not email_sent:
logger.error(f"Failed to send verification email to {user.email}") logger.error(f"Failed to send verification email to {user.email}")
# We don't do rollback here, we just log the error # We don't do rollback here, we just log the error
logger.info(f"User created successfully: {user.email}") logger.info(f"User created successfully: {user.email}")
return user, "User created successfully. Check your email to activate your account." return (
user,
"User created successfully. Check your email to activate your account.",
)
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error creating user: {str(e)}") logger.error(f"Error creating user: {str(e)}")
return None, f"Error creating user: {str(e)}" return None, f"Error creating user: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error creating user: {str(e)}") logger.error(f"Unexpected error creating user: {str(e)}")
return None, f"Unexpected error: {str(e)}" return None, f"Unexpected error: {str(e)}"
def verify_email(db: Session, token: str) -> Tuple[bool, str]: def verify_email(db: Session, token: str) -> Tuple[bool, str]:
""" """
Verify the user's email using the provided token Verify the user's email using the provided token
Args: Args:
db: Database session db: Database session
token: Verification token token: Verification token
Returns: Returns:
Tuple[bool, str]: Tuple with verification status and message Tuple[bool, str]: Tuple with verification status and message
""" """
try: try:
# Search for user by token # Search for user by token
user = db.query(User).filter(User.verification_token == token).first() user = db.query(User).filter(User.verification_token == token).first()
if not user: if not user:
logger.warning(f"Attempt to verify with invalid token: {token}") logger.warning(f"Attempt to verify with invalid token: {token}")
return False, "Invalid verification token" return False, "Invalid verification token"
# Check if the token has expired # Check if the token has expired
now = datetime.utcnow() now = datetime.utcnow()
expiry = user.verification_token_expiry expiry = user.verification_token_expiry
# Ensure both dates are of the same type (aware or naive) # Ensure both dates are of the same type (aware or naive)
if expiry.tzinfo is not None and now.tzinfo is None: if expiry.tzinfo is not None and now.tzinfo is None:
# If expiry has timezone and now doesn't, convert now to have timezone # If expiry has timezone and now doesn't, convert now to have timezone
@ -109,180 +124,201 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
elif now.tzinfo is not None and expiry.tzinfo is None: elif now.tzinfo is not None and expiry.tzinfo is None:
# If now has timezone and expiry doesn't, convert expiry to have timezone # If now has timezone and expiry doesn't, convert expiry to have timezone
expiry = expiry.replace(tzinfo=now.tzinfo) expiry = expiry.replace(tzinfo=now.tzinfo)
if expiry < now: if expiry < now:
logger.warning(f"Attempt to verify with expired token for user: {user.email}") logger.warning(
f"Attempt to verify with expired token for user: {user.email}"
)
return False, "Verification token expired" return False, "Verification token expired"
# Update user # Update user
user.email_verified = True user.email_verified = True
user.is_active = True user.is_active = True
user.verification_token = None user.verification_token = None
user.verification_token_expiry = None user.verification_token_expiry = None
db.commit() db.commit()
logger.info(f"Email verified successfully for user: {user.email}") logger.info(f"Email verified successfully for user: {user.email}")
return True, "Email verified successfully. Your account is active." return True, "Email verified successfully. Your account is active."
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error verifying email: {str(e)}") logger.error(f"Error verifying email: {str(e)}")
return False, f"Error verifying email: {str(e)}" return False, f"Error verifying email: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error verifying email: {str(e)}") logger.error(f"Unexpected error verifying email: {str(e)}")
return False, f"Unexpected error: {str(e)}" return False, f"Unexpected error: {str(e)}"
def resend_verification(db: Session, email: str) -> Tuple[bool, str]: def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
""" """
Resend the verification email Resend the verification email
Args: Args:
db: Database session db: Database session
email: User email email: User email
Returns: Returns:
Tuple[bool, str]: Tuple with operation status and message Tuple[bool, str]: Tuple with operation status and message
""" """
try: try:
# Search for user by email # Search for user by email
user = db.query(User).filter(User.email == email).first() user = db.query(User).filter(User.email == email).first()
if not user: if not user:
logger.warning(f"Attempt to resend verification email for non-existent email: {email}") logger.warning(
f"Attempt to resend verification email for non-existent email: {email}"
)
return False, "Email not found" return False, "Email not found"
if user.email_verified: if user.email_verified:
logger.info(f"Attempt to resend verification email for already verified email: {email}") logger.info(
f"Attempt to resend verification email for already verified email: {email}"
)
return False, "Email already verified" return False, "Email already verified"
# Generate new token # Generate new token
verification_token = generate_token() verification_token = generate_token()
token_expiry = datetime.utcnow() + timedelta(hours=24) token_expiry = datetime.utcnow() + timedelta(hours=24)
# Update user # Update user
user.verification_token = verification_token user.verification_token = verification_token
user.verification_token_expiry = token_expiry user.verification_token_expiry = token_expiry
db.commit() db.commit()
# Send email # Send email
email_sent = send_verification_email(user.email, verification_token) email_sent = send_verification_email(user.email, verification_token)
if not email_sent: if not email_sent:
logger.error(f"Failed to resend verification email to {user.email}") logger.error(f"Failed to resend verification email to {user.email}")
return False, "Failed to send verification email" return False, "Failed to send verification email"
logger.info(f"Verification email resent successfully to: {user.email}") logger.info(f"Verification email resent successfully to: {user.email}")
return True, "Verification email resent. Check your inbox." return True, "Verification email resent. Check your inbox."
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error resending verification: {str(e)}") logger.error(f"Error resending verification: {str(e)}")
return False, f"Error resending verification: {str(e)}" return False, f"Error resending verification: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error resending verification: {str(e)}") logger.error(f"Unexpected error resending verification: {str(e)}")
return False, f"Unexpected error: {str(e)}" return False, f"Unexpected error: {str(e)}"
def forgot_password(db: Session, email: str) -> Tuple[bool, str]: def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
""" """
Initiates the password recovery process Initiates the password recovery process
Args: Args:
db: Database session db: Database session
email: User email email: User email
Returns: Returns:
Tuple[bool, str]: Tuple with operation status and message Tuple[bool, str]: Tuple with operation status and message
""" """
try: try:
# Search for user by email # Search for user by email
user = db.query(User).filter(User.email == email).first() user = db.query(User).filter(User.email == email).first()
if not user: if not user:
# For security, we don't inform if the email exists or not # For security, we don't inform if the email exists or not
logger.info(f"Attempt to recover password for non-existent email: {email}") logger.info(f"Attempt to recover password for non-existent email: {email}")
return True, "If the email is registered, you will receive instructions to reset your password." return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)
# Generate reset token # Generate reset token
reset_token = generate_token() reset_token = generate_token()
token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour
# Update user # Update user
user.password_reset_token = reset_token user.password_reset_token = reset_token
user.password_reset_expiry = token_expiry user.password_reset_expiry = token_expiry
db.commit() db.commit()
# Send email # Send email
email_sent = send_password_reset_email(user.email, reset_token) email_sent = send_password_reset_email(user.email, reset_token)
if not email_sent: if not email_sent:
logger.error(f"Failed to send password reset email to {user.email}") logger.error(f"Failed to send password reset email to {user.email}")
return False, "Failed to send password reset email" return False, "Failed to send password reset email"
logger.info(f"Password reset email sent successfully to: {user.email}") logger.info(f"Password reset email sent successfully to: {user.email}")
return True, "If the email is registered, you will receive instructions to reset your password." return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error processing password recovery: {str(e)}") logger.error(f"Error processing password recovery: {str(e)}")
return False, f"Error processing password recovery: {str(e)}" return False, f"Error processing password recovery: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error processing password recovery: {str(e)}") logger.error(f"Unexpected error processing password recovery: {str(e)}")
return False, f"Unexpected error: {str(e)}" return False, f"Unexpected error: {str(e)}"
def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]: def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
""" """
Resets the user's password using the provided token Resets the user's password using the provided token
Args: Args:
db: Database session db: Database session
token: Password reset token token: Password reset token
new_password: New password new_password: New password
Returns: Returns:
Tuple[bool, str]: Tuple with operation status and message Tuple[bool, str]: Tuple with operation status and message
""" """
try: try:
# Search for user by token # Search for user by token
user = db.query(User).filter(User.password_reset_token == token).first() user = db.query(User).filter(User.password_reset_token == token).first()
if not user: if not user:
logger.warning(f"Attempt to reset password with invalid token: {token}") logger.warning(f"Attempt to reset password with invalid token: {token}")
return False, "Invalid password reset token" return False, "Invalid password reset token"
# Check if the token has expired # Check if the token has expired
if user.password_reset_expiry < datetime.utcnow(): if user.password_reset_expiry < datetime.utcnow():
logger.warning(f"Attempt to reset password with expired token for user: {user.email}") logger.warning(
f"Attempt to reset password with expired token for user: {user.email}"
)
return False, "Password reset token expired" return False, "Password reset token expired"
# Update password # Update password
user.password_hash = get_password_hash(new_password) user.password_hash = get_password_hash(new_password)
user.password_reset_token = None user.password_reset_token = None
user.password_reset_expiry = None user.password_reset_expiry = None
db.commit() db.commit()
logger.info(f"Password reset successfully for user: {user.email}") logger.info(f"Password reset successfully for user: {user.email}")
return True, "Password reset successfully. You can now login with your new password." return (
True,
"Password reset successfully. You can now login with your new password.",
)
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error resetting password: {str(e)}") logger.error(f"Error resetting password: {str(e)}")
return False, f"Error resetting password: {str(e)}" return False, f"Error resetting password: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error resetting password: {str(e)}") logger.error(f"Unexpected error resetting password: {str(e)}")
return False, f"Unexpected error: {str(e)}" return False, f"Unexpected error: {str(e)}"
def get_user_by_email(db: Session, email: str) -> Optional[User]: def get_user_by_email(db: Session, email: str) -> Optional[User]:
""" """
Searches for a user by email Searches for a user by email
Args: Args:
db: Database session db: Database session
email: User email email: User email
Returns: Returns:
Optional[User]: User found or None Optional[User]: User found or None
""" """
@ -292,15 +328,16 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:
logger.error(f"Error searching for user by email: {str(e)}") logger.error(f"Error searching for user by email: {str(e)}")
return None return None
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
""" """
Authenticates a user with email and password Authenticates a user with email and password
Args: Args:
db: Database session db: Database session
email: User email email: User email
password: User password password: User password
Returns: Returns:
Optional[User]: Authenticated user or None Optional[User]: Authenticated user or None
""" """
@ -313,75 +350,78 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
return None return None
return user return user
def get_admin_users(db: Session, skip: int = 0, limit: int = 100): def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
""" """
Lists the admin users Lists the admin users
Args: Args:
db: Database session db: Database session
skip: Number of records to skip skip: Number of records to skip
limit: Maximum number of records to return limit: Maximum number of records to return
Returns: Returns:
List[User]: List of admin users List[User]: List of admin users
""" """
try: try:
users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all() users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
logger.info(f"List of admins: {len(users)} found") logger.info(f"List of admins: {len(users)} found")
return users return users
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Error listing admins: {str(e)}") logger.error(f"Error listing admins: {str(e)}")
return [] return []
except Exception as e: except Exception as e:
logger.error(f"Unexpected error listing admins: {str(e)}") logger.error(f"Unexpected error listing admins: {str(e)}")
return [] return []
def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]: def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
""" """
Creates a new admin user Creates a new admin user
Args: Args:
db: Database session db: Database session
user_data: User data to be created user_data: User data to be created
Returns: Returns:
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
""" """
return create_user(db, user_data, is_admin=True) return create_user(db, user_data, is_admin=True)
def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]: def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
""" """
Deactivates a user (does not delete, only marks as inactive) Deactivates a user (does not delete, only marks as inactive)
Args: Args:
db: Database session db: Database session
user_id: ID of the user to be deactivated user_id: ID of the user to be deactivated
Returns: Returns:
Tuple[bool, str]: Tuple with operation status and message Tuple[bool, str]: Tuple with operation status and message
""" """
try: try:
# Search for user by ID # Search for user by ID
user = db.query(User).filter(User.id == user_id).first() user = db.query(User).filter(User.id == user_id).first()
if not user: if not user:
logger.warning(f"Attempt to deactivate non-existent user: {user_id}") logger.warning(f"Attempt to deactivate non-existent user: {user_id}")
return False, "User not found" return False, "User not found"
# Deactivate user # Deactivate user
user.is_active = False user.is_active = False
db.commit() db.commit()
logger.info(f"User deactivated successfully: {user.email}") logger.info(f"User deactivated successfully: {user.email}")
return True, "User deactivated successfully" return True, "User deactivated successfully"
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
logger.error(f"Error deactivating user: {str(e)}") logger.error(f"Error deactivating user: {str(e)}")
return False, f"Error deactivating user: {str(e)}" return False, f"Error deactivating user: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error deactivating user: {str(e)}") logger.error(f"Unexpected error deactivating user: {str(e)}")
return False, f"Unexpected error: {str(e)}" return False, f"Unexpected error: {str(e)}"

View File

@ -3,23 +3,26 @@ import os
import sys import sys
from src.config.settings import settings from src.config.settings import settings
class CustomFormatter(logging.Formatter): class CustomFormatter(logging.Formatter):
"""Custom formatter for logs""" """Custom formatter for logs"""
grey = "\x1b[38;20m" grey = "\x1b[38;20m"
yellow = "\x1b[33;20m" yellow = "\x1b[33;20m"
red = "\x1b[31;20m" red = "\x1b[31;20m"
bold_red = "\x1b[31;1m" bold_red = "\x1b[31;1m"
reset = "\x1b[0m" reset = "\x1b[0m"
format_template = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" format_template = (
"%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
)
FORMATS = { FORMATS = {
logging.DEBUG: grey + format_template + reset, logging.DEBUG: grey + format_template + reset,
logging.INFO: grey + format_template + reset, logging.INFO: grey + format_template + reset,
logging.WARNING: yellow + format_template + reset, logging.WARNING: yellow + format_template + reset,
logging.ERROR: red + format_template + reset, logging.ERROR: red + format_template + reset,
logging.CRITICAL: bold_red + format_template + reset logging.CRITICAL: bold_red + format_template + reset,
} }
def format(self, record): def format(self, record):
@ -27,33 +30,34 @@ class CustomFormatter(logging.Formatter):
formatter = logging.Formatter(log_fmt) formatter = logging.Formatter(log_fmt)
return formatter.format(record) return formatter.format(record)
def setup_logger(name: str) -> logging.Logger: def setup_logger(name: str) -> logging.Logger:
""" """
Configures a custom logger Configures a custom logger
Args: Args:
name: Logger name name: Logger name
Returns: Returns:
logging.Logger: Logger configurado logging.Logger: Logger configurado
""" """
logger = logging.getLogger(name) logger = logging.getLogger(name)
# Remove existing handlers to avoid duplication # Remove existing handlers to avoid duplication
if logger.handlers: if logger.handlers:
logger.handlers.clear() logger.handlers.clear()
# Configure the logger level based on the environment variable or configuration # Configure the logger level based on the environment variable or configuration
log_level = getattr(logging, os.getenv("LOG_LEVEL", settings.LOG_LEVEL).upper()) log_level = getattr(logging, os.getenv("LOG_LEVEL", settings.LOG_LEVEL).upper())
logger.setLevel(log_level) logger.setLevel(log_level)
# Console handler # Console handler
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(CustomFormatter()) console_handler.setFormatter(CustomFormatter())
console_handler.setLevel(log_level) console_handler.setLevel(log_level)
logger.addHandler(console_handler) logger.addHandler(console_handler)
# Prevent logs from being propagated to the root logger # Prevent logs from being propagated to the root logger
logger.propagate = False logger.propagate = False
return logger return logger

View File

@ -11,41 +11,44 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Fix bcrypt error with passlib # Fix bcrypt error with passlib
if not hasattr(bcrypt, '__about__'): if not hasattr(bcrypt, "__about__"):
@dataclass @dataclass
class BcryptAbout: class BcryptAbout:
__version__: str = getattr(bcrypt, "__version__") __version__: str = getattr(bcrypt, "__version__")
setattr(bcrypt, "__about__", BcryptAbout()) setattr(bcrypt, "__about__", BcryptAbout())
# Context for password hashing using bcrypt # Context for password hashing using bcrypt
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
"""Creates a password hash""" """Creates a password hash"""
return pwd_context.hash(password) return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies if the provided password matches the stored hash""" """Verifies if the provided password matches the stored hash"""
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str: def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
"""Creates a JWT token""" """Creates a JWT token"""
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
else: else:
expire = datetime.utcnow() + timedelta( expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME)
minutes=settings.JWT_EXPIRATION_TIME
)
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
) )
return encoded_jwt return encoded_jwt
def generate_token(length: int = 32) -> str: def generate_token(length: int = 32) -> str:
"""Generates a secure token for email verification or password reset""" """Generates a secure token for email verification or password reset"""
alphabet = string.ascii_letters + string.digits alphabet = string.ascii_letters + string.digits
token = ''.join(secrets.choice(alphabet) for _ in range(length)) token = "".join(secrets.choice(alphabet) for _ in range(length))
return token return token

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# Package initialization for tests

1
tests/api/__init__.py Normal file
View File

@ -0,0 +1 @@
# API tests package

11
tests/api/test_root.py Normal file
View File

@ -0,0 +1,11 @@
def test_read_root(client):
"""
Test that the root endpoint returns the correct response.
"""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert "message" in data
assert "documentation" in data
assert "version" in data
assert "auth" in data

View File

@ -0,0 +1 @@
# Services tests package

View File

@ -0,0 +1,27 @@
from src.services.auth_service import create_access_token
from src.models.models import User
import uuid
def test_create_access_token():
"""
Test that an access token is created with the correct data.
"""
# Create a mock user
user = User(
id=uuid.uuid4(),
email="test@example.com",
hashed_password="hashed_password",
is_active=True,
is_admin=False,
name="Test User",
client_id=uuid.uuid4(),
)
# Create token
token = create_access_token(user)
# Simple validation
assert token is not None
assert isinstance(token, str)
assert len(token) > 0