chore: update project structure and add testing framework
This commit is contained in:
parent
7af234ef48
commit
e7e030dfd5
77
.cursorrules
77
.cursorrules
@ -17,9 +17,16 @@
|
||||
```
|
||||
src/
|
||||
├── api/
|
||||
│ ├── routes.py # API routes definition
|
||||
│ ├── auth_routes.py # Authentication routes (login, registration, etc.)
|
||||
│ └── admin_routes.py # Protected admin routes
|
||||
│ ├── __init__.py # Package initialization
|
||||
│ ├── admin_routes.py # Admin routes for management interface
|
||||
│ ├── agent_routes.py # Routes to manage agents
|
||||
│ ├── auth_routes.py # Authentication routes (login, registration)
|
||||
│ ├── chat_routes.py # Routes for chat interactions with agents
|
||||
│ ├── client_routes.py # Routes to manage clients
|
||||
│ ├── contact_routes.py # Routes to manage contacts
|
||||
│ ├── mcp_server_routes.py # Routes to manage MCP servers
|
||||
│ ├── session_routes.py # Routes to manage chat sessions
|
||||
│ └── tool_routes.py # Routes to manage tools for agents
|
||||
├── config/
|
||||
│ ├── database.py # Database configuration
|
||||
│ └── settings.py # General settings
|
||||
@ -30,18 +37,21 @@ src/
|
||||
│ └── models.py # SQLAlchemy models
|
||||
├── schemas/
|
||||
│ ├── schemas.py # Main Pydantic schemas
|
||||
│ ├── chat.py # Chat schemas
|
||||
│ ├── user.py # User and authentication schemas
|
||||
│ └── audit.py # Audit logs schemas
|
||||
├── services/
|
||||
│ ├── agent_service.py # Business logic for agents
|
||||
│ ├── agent_runner.py # Agent execution logic
|
||||
│ ├── auth_service.py # JWT authentication logic
|
||||
│ ├── audit_service.py # Audit logs logic
|
||||
│ ├── client_service.py # Business logic for clients
|
||||
│ ├── contact_service.py # Business logic for contacts
|
||||
│ ├── mcp_server_service.py # Business logic for MCP servers
|
||||
│ ├── tool_service.py # Business logic for tools
|
||||
│ ├── user_service.py # User and authentication logic
|
||||
│ ├── auth_service.py # JWT authentication logic
|
||||
│ ├── email_service.py # Email sending service
|
||||
│ └── audit_service.py # Audit logs logic
|
||||
│ ├── mcp_server_service.py # Business logic for MCP servers
|
||||
│ ├── session_service.py # Business logic for chat sessions
|
||||
│ ├── tool_service.py # Business logic for tools
|
||||
│ └── user_service.py # User management logic
|
||||
├── templates/
|
||||
│ ├── emails/
|
||||
│ │ ├── base_email.html # Base template with common structure and styles
|
||||
@ -49,7 +59,21 @@ src/
|
||||
│ │ ├── password_reset.html # Password reset template
|
||||
│ │ ├── welcome_email.html # Welcome email after verification
|
||||
│ │ └── account_locked.html # Security alert for locked accounts
|
||||
├── tests/
|
||||
│ ├── __init__.py # Package initialization
|
||||
│ ├── api/
|
||||
│ │ ├── __init__.py # Package initialization
|
||||
│ │ ├── test_auth_routes.py # Test for authentication routes
|
||||
│ │ └── test_root.py # Test for root endpoint
|
||||
│ ├── models/
|
||||
│ │ ├── __init__.py # Package initialization
|
||||
│ │ ├── test_models.py # Test for models
|
||||
│ ├── services/
|
||||
│ │ ├── __init__.py # Package initialization
|
||||
│ │ ├── test_auth_service.py # Test for authentication service
|
||||
│ │ └── test_user_service.py # Test for user service
|
||||
└── utils/
|
||||
├── logger.py # Logger configuration
|
||||
└── security.py # Security utilities (JWT, hash)
|
||||
```
|
||||
|
||||
@ -63,6 +87,15 @@ src/
|
||||
- Code examples in documentation must be in English
|
||||
- Commit messages must be in English
|
||||
|
||||
### Project Configuration
|
||||
- Dependencies managed in `pyproject.toml` using modern Python packaging standards
|
||||
- Development dependencies specified as optional dependencies in `pyproject.toml`
|
||||
- Single source of truth for project metadata in `pyproject.toml`
|
||||
- Build system configured to use setuptools
|
||||
- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]`
|
||||
- Code formatting with Black configured in `pyproject.toml`
|
||||
- Linting with Flake8 configured in `.flake8`
|
||||
|
||||
### Schemas (Pydantic)
|
||||
- Use `BaseModel` as base for all schemas
|
||||
- Define fields with explicit types
|
||||
@ -136,6 +169,28 @@ src/
|
||||
- Indentation with 4 spaces
|
||||
- Maximum of 79 characters per line
|
||||
|
||||
## Commit Rules
|
||||
- Use Conventional Commits format for all commit messages
|
||||
- Format: `<type>(<scope>): <description>`
|
||||
- Types:
|
||||
- `feat`: A new feature
|
||||
- `fix`: A bug fix
|
||||
- `docs`: Documentation changes
|
||||
- `style`: Changes that do not affect code meaning (formatting, etc.)
|
||||
- `refactor`: Code changes that neither fix a bug nor add a feature
|
||||
- `perf`: Performance improvements
|
||||
- `test`: Adding or modifying tests
|
||||
- `chore`: Changes to build process or auxiliary tools
|
||||
- Scope is optional and should be the module or component affected
|
||||
- Description should be concise, in the imperative mood, and not capitalized
|
||||
- Use body for more detailed explanations if needed
|
||||
- Reference issues in the footer with `Fixes #123` or `Relates to #123`
|
||||
- Examples:
|
||||
- `feat(auth): add password reset functionality`
|
||||
- `fix(api): correct validation error in client registration`
|
||||
- `docs: update API documentation for new endpoints`
|
||||
- `refactor(services): improve error handling in authentication`
|
||||
|
||||
## Best Practices
|
||||
- Always validate input data
|
||||
- Implement appropriate logging
|
||||
@ -163,6 +218,7 @@ src/
|
||||
|
||||
## Useful Commands
|
||||
- `make run`: Start the server
|
||||
- `make run-prod`: Start the server in production mode
|
||||
- `make alembic-revision message="description"`: Create new migration
|
||||
- `make alembic-upgrade`: Apply pending migrations
|
||||
- `make alembic-downgrade`: Revert last migration
|
||||
@ -170,3 +226,8 @@ src/
|
||||
- `make alembic-reset`: Reset database to initial state
|
||||
- `make alembic-upgrade-cascade`: Force upgrade removing dependencies
|
||||
- `make clear-cache`: Clean project cache
|
||||
- `make seed-all`: Run all database seeders
|
||||
- `make lint`: Run linting checks with flake8
|
||||
- `make format`: Format code with black
|
||||
- `make install`: Install project for development
|
||||
- `make install-dev`: Install project with development dependencies
|
||||
|
@ -1,3 +1,47 @@
|
||||
# Environment and IDE
|
||||
.venv
|
||||
venv
|
||||
.env
|
||||
.idea
|
||||
.vscode
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.pytest_cache
|
||||
.coverage
|
||||
htmlcov/
|
||||
.tox/
|
||||
|
||||
# Version control
|
||||
.git
|
||||
.github
|
||||
.gitignore
|
||||
|
||||
# Logs and temp files
|
||||
logs
|
||||
*.log
|
||||
tmp
|
||||
.DS_Store
|
||||
|
||||
# Docker
|
||||
.dockerignore
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
|
||||
# Documentation
|
||||
README.md
|
||||
LICENSE
|
||||
docs/
|
||||
|
||||
# Development tools
|
||||
tests/
|
||||
.flake8
|
||||
pyproject.toml
|
||||
requirements-dev.txt
|
||||
Makefile
|
||||
|
||||
# Ambiente virtual
|
||||
venv/
|
||||
__pycache__/
|
||||
|
8
.flake8
Normal file
8
.flake8
Normal file
@ -0,0 +1,8 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
exclude = .git,__pycache__,venv,alembic/versions/*
|
||||
ignore = E203, W503, E501
|
||||
per-file-ignores =
|
||||
__init__.py: F401
|
||||
src/models/models.py: E712
|
||||
alembic/*: E711,E712,F401
|
77
.gitignore
vendored
77
.gitignore
vendored
@ -1,13 +1,10 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
@ -19,19 +16,10 @@ lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
@ -42,14 +30,57 @@ nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.sublime-project
|
||||
*.sublime-workspace
|
||||
.DS_Store
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
backup/
|
||||
|
||||
# Local
|
||||
local_settings.py
|
||||
local.py
|
||||
|
||||
# Docker
|
||||
.docker/
|
||||
|
||||
# Alembic versions
|
||||
# alembic/versions/
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
@ -77,17 +108,6 @@ celerybeat-schedule
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
.venv/
|
||||
.env/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
@ -102,11 +122,8 @@ venv.bak/
|
||||
.mypy_cache/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
16
Dockerfile
16
Dockerfile
@ -1,4 +1,4 @@
|
||||
FROM python:3.11-slim
|
||||
FROM python:3.10-slim
|
||||
|
||||
# Define o diretório de trabalho
|
||||
WORKDIR /app
|
||||
@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copia os arquivos de requisitos
|
||||
COPY requirements.txt .
|
||||
|
||||
# Instala as dependências
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copia o código-fonte
|
||||
# Copy project files
|
||||
COPY . .
|
||||
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache-dir -e .
|
||||
|
||||
# Configuração para produção
|
||||
ENV PORT=8000 \
|
||||
HOST=0.0.0.0 \
|
||||
DEBUG=false
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Define o comando de inicialização
|
||||
CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT
|
51
Makefile
51
Makefile
@ -1,42 +1,42 @@
|
||||
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs
|
||||
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv
|
||||
|
||||
# Comandos do Alembic
|
||||
# Alembic commands
|
||||
init:
|
||||
alembic init alembics
|
||||
|
||||
# make alembic-revision message="descrição da migração"
|
||||
# make alembic-revision message="migration description"
|
||||
alembic-revision:
|
||||
alembic revision --autogenerate -m "$(message)"
|
||||
|
||||
# Comando para atualizar o banco de dados
|
||||
# Command to update database to latest version
|
||||
alembic-upgrade:
|
||||
alembic upgrade head
|
||||
|
||||
# Comando para voltar uma versão
|
||||
# Command to downgrade one version
|
||||
alembic-downgrade:
|
||||
alembic downgrade -1
|
||||
|
||||
# Comando para rodar o servidor
|
||||
# Command to run the server
|
||||
run:
|
||||
uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src
|
||||
|
||||
# Comando para limpar o cache em todas as pastas do projeto pastas pycache
|
||||
# Command to run the server in production mode
|
||||
run-prod:
|
||||
uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4
|
||||
|
||||
# Command to clean cache in all project folders
|
||||
clear-cache:
|
||||
rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +
|
||||
|
||||
# Comando para criar uma nova migração
|
||||
# Command to create a new migration and apply it
|
||||
alembic-migrate:
|
||||
alembic revision --autogenerate -m "$(message)" && alembic upgrade head
|
||||
|
||||
# Comando para resetar o banco de dados
|
||||
# Command to reset the database
|
||||
alembic-reset:
|
||||
alembic downgrade base && alembic upgrade head
|
||||
|
||||
# Comando para forçar upgrade com CASCADE
|
||||
alembic-upgrade-cascade:
|
||||
psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head
|
||||
|
||||
# Comandos para executar seeders
|
||||
|
||||
# Commands to run seeders
|
||||
seed-admin:
|
||||
python -m scripts.seeders.admin_seeder
|
||||
|
||||
@ -58,7 +58,7 @@ seed-contacts:
|
||||
seed-all:
|
||||
python -m scripts.run_seeders
|
||||
|
||||
# Comandos Docker
|
||||
# Docker commands
|
||||
docker-build:
|
||||
docker-compose build
|
||||
|
||||
@ -72,4 +72,21 @@ docker-logs:
|
||||
docker-compose logs -f
|
||||
|
||||
docker-seed:
|
||||
docker-compose exec api python -m scripts.run_seeders
|
||||
docker-compose exec api python -m scripts.run_seeders
|
||||
|
||||
# Testing, linting and formatting commands
|
||||
lint:
|
||||
flake8 src/ tests/
|
||||
|
||||
format:
|
||||
black src/ tests/
|
||||
|
||||
# Virtual environment and installation commands
|
||||
venv:
|
||||
python -m venv venv
|
||||
|
||||
install:
|
||||
pip install -e .
|
||||
|
||||
install-dev:
|
||||
pip install -e ".[dev]"
|
53
conftest.py
Normal file
53
conftest.py
Normal file
@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from src.config.database import Base, get_db
|
||||
from src.main import app
|
||||
|
||||
# Use in-memory SQLite for tests
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Creates a fresh database session for each test."""
|
||||
Base.metadata.create_all(bind=engine) # Create tables
|
||||
|
||||
connection = engine.connect()
|
||||
transaction = connection.begin()
|
||||
session = TestingSessionLocal(bind=connection)
|
||||
|
||||
# Use our test database instead of the standard one
|
||||
def override_get_db():
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
yield session # The test will run here
|
||||
|
||||
# Teardown
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db_session):
|
||||
"""Creates a FastAPI TestClient with database session fixture."""
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
89
pyproject.toml
Normal file
89
pyproject.toml
Normal file
@ -0,0 +1,89 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "evo-ai"
|
||||
version = "1.0.0"
|
||||
description = "API for executing AI agents"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{name = "EvoAI Team", email = "admin@evoai.com"}
|
||||
]
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "Proprietary"}
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"License :: Other/Proprietary License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
|
||||
# Main dependencies
|
||||
dependencies = [
|
||||
"fastapi==0.115.12",
|
||||
"uvicorn==0.34.2",
|
||||
"pydantic==2.11.3",
|
||||
"sqlalchemy==2.0.40",
|
||||
"psycopg2-binary==2.9.10",
|
||||
"google-cloud-aiplatform==1.90.0",
|
||||
"python-dotenv==1.1.0",
|
||||
"google-adk==0.3.0",
|
||||
"litellm==1.67.4.post1",
|
||||
"python-multipart==0.0.20",
|
||||
"alembic==1.15.2",
|
||||
"asyncpg==0.30.0",
|
||||
"python-jose==3.4.0",
|
||||
"passlib==1.7.4",
|
||||
"sendgrid==6.11.0",
|
||||
"pydantic-settings==2.9.1",
|
||||
"fastapi_utils==0.8.0",
|
||||
"bcrypt==4.3.0",
|
||||
"jinja2==3.1.6",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"black==25.1.0",
|
||||
"flake8==7.2.0",
|
||||
"pytest==8.3.5",
|
||||
"pytest-cov==6.1.1",
|
||||
"httpx==0.28.1",
|
||||
"pytest-asyncio==0.26.0",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["src"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py310"]
|
||||
include = '\.pyi?$'
|
||||
exclude = '''
|
||||
/(
|
||||
\.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| _build
|
||||
| buck-out
|
||||
| build
|
||||
| dist
|
||||
| venv
|
||||
| alembic/versions
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = "test_*.py"
|
||||
python_functions = "test_*"
|
||||
python_classes = "Test*"
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["src"]
|
||||
omit = ["tests/*", "alembic/*"]
|
@ -1,22 +0,0 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
sqlalchemy
|
||||
psycopg2
|
||||
psycopg2-binary
|
||||
google-cloud-aiplatform
|
||||
python-dotenv
|
||||
google-adk
|
||||
litellm
|
||||
python-multipart
|
||||
alembic
|
||||
asyncpg
|
||||
# Novas dependências para autenticação
|
||||
python-jose[cryptography]
|
||||
passlib[bcrypt]
|
||||
sendgrid
|
||||
pydantic[email]
|
||||
pydantic-settings
|
||||
fastapi_utils
|
||||
bcrypt
|
||||
jinja2
|
6
setup.py
Normal file
6
setup.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Setup script for the package."""
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup()
|
@ -1,3 +1,3 @@
|
||||
"""
|
||||
Pacote principal da aplicação
|
||||
"""
|
||||
Main package of the application
|
||||
"""
|
||||
|
@ -1,14 +1,17 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from src.config.database import get_db
|
||||
from src.core.jwt_middleware import get_jwt_token, verify_admin
|
||||
from src.schemas.audit import AuditLogResponse, AuditLogFilter
|
||||
from src.services.audit_service import get_audit_logs, create_audit_log
|
||||
from src.services.user_service import get_admin_users, create_admin_user, deactivate_user
|
||||
from src.services.user_service import (
|
||||
get_admin_users,
|
||||
create_admin_user,
|
||||
deactivate_user,
|
||||
)
|
||||
from src.schemas.user import UserResponse, AdminUserCreate
|
||||
|
||||
router = APIRouter(
|
||||
@ -18,6 +21,7 @@ router = APIRouter(
|
||||
responses={403: {"description": "Permission denied"}},
|
||||
)
|
||||
|
||||
|
||||
# Audit routes
|
||||
@router.get("/audit-logs", response_model=List[AuditLogResponse])
|
||||
async def read_audit_logs(
|
||||
@ -27,12 +31,12 @@ async def read_audit_logs(
|
||||
):
|
||||
"""
|
||||
Get audit logs with optional filters
|
||||
|
||||
|
||||
Args:
|
||||
filters: Filters for log search
|
||||
db: Database session
|
||||
payload: JWT token payload
|
||||
|
||||
|
||||
Returns:
|
||||
List[AuditLogResponse]: List of audit logs
|
||||
"""
|
||||
@ -45,9 +49,10 @@ async def read_audit_logs(
|
||||
resource_type=filters.resource_type,
|
||||
resource_id=filters.resource_id,
|
||||
start_date=filters.start_date,
|
||||
end_date=filters.end_date
|
||||
end_date=filters.end_date,
|
||||
)
|
||||
|
||||
|
||||
# Admin routes
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
async def read_admin_users(
|
||||
@ -58,18 +63,19 @@ async def read_admin_users(
|
||||
):
|
||||
"""
|
||||
List admin users
|
||||
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
db: Database session
|
||||
payload: JWT token payload
|
||||
|
||||
|
||||
Returns:
|
||||
List[UserResponse]: List of admin users
|
||||
"""
|
||||
return get_admin_users(db, skip, limit)
|
||||
|
||||
|
||||
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_new_admin_user(
|
||||
user_data: AdminUserCreate,
|
||||
@ -79,16 +85,16 @@ async def create_new_admin_user(
|
||||
):
|
||||
"""
|
||||
Create a new admin user
|
||||
|
||||
|
||||
Args:
|
||||
user_data: User data to be created
|
||||
request: FastAPI Request object
|
||||
db: Database session
|
||||
payload: JWT token payload
|
||||
|
||||
|
||||
Returns:
|
||||
UserResponse: Created user data
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error in creation
|
||||
"""
|
||||
@ -97,17 +103,14 @@ async def create_new_admin_user(
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unable to identify the logged in user"
|
||||
detail="Unable to identify the logged in user",
|
||||
)
|
||||
|
||||
|
||||
# Create admin user
|
||||
user, message = create_admin_user(db, user_data)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=message
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
||||
|
||||
# Register action in audit log
|
||||
create_audit_log(
|
||||
db,
|
||||
@ -116,11 +119,12 @@ async def create_new_admin_user(
|
||||
resource_type="admin_user",
|
||||
resource_id=str(user.id),
|
||||
details={"email": user.email},
|
||||
request=request
|
||||
request=request,
|
||||
)
|
||||
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def deactivate_admin_user(
|
||||
user_id: uuid.UUID,
|
||||
@ -130,13 +134,13 @@ async def deactivate_admin_user(
|
||||
):
|
||||
"""
|
||||
Deactivate an admin user (does not delete, only deactivates)
|
||||
|
||||
|
||||
Args:
|
||||
user_id: ID of the user to be deactivated
|
||||
request: FastAPI Request object
|
||||
db: Database session
|
||||
payload: JWT token payload
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error in deactivation
|
||||
"""
|
||||
@ -145,24 +149,21 @@ async def deactivate_admin_user(
|
||||
if not current_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unable to identify the logged in user"
|
||||
detail="Unable to identify the logged in user",
|
||||
)
|
||||
|
||||
|
||||
# Do not allow deactivating yourself
|
||||
if str(user_id) == current_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Unable to deactivate your own user"
|
||||
detail="Unable to deactivate your own user",
|
||||
)
|
||||
|
||||
|
||||
# Deactivate user
|
||||
success, message = deactivate_user(db, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=message
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
||||
|
||||
# Register action in audit log
|
||||
create_audit_log(
|
||||
db,
|
||||
@ -171,5 +172,5 @@ async def deactivate_admin_user(
|
||||
resource_type="admin_user",
|
||||
resource_id=str(user_id),
|
||||
details=None,
|
||||
request=request
|
||||
)
|
||||
request=request,
|
||||
)
|
||||
|
@ -7,10 +7,6 @@ from src.core.jwt_middleware import (
|
||||
get_jwt_token,
|
||||
verify_user_client,
|
||||
)
|
||||
from src.core.jwt_middleware import (
|
||||
get_jwt_token,
|
||||
verify_user_client,
|
||||
)
|
||||
from src.schemas.schemas import (
|
||||
Agent,
|
||||
AgentCreate,
|
||||
|
@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get
|
||||
"""
|
||||
user = authenticate_user(db, form_data.email, form_data.password)
|
||||
if not user:
|
||||
logger.warning(
|
||||
f"Login attempt with invalid credentials: {form_data.email}"
|
||||
)
|
||||
logger.warning(f"Login attempt with invalid credentials: {form_data.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
|
@ -11,7 +11,11 @@ from src.services import (
|
||||
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
|
||||
from src.services.agent_runner import run_agent
|
||||
from src.core.exceptions import AgentNotFoundError
|
||||
from src.main import session_service, artifacts_service, memory_service
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
import logging
|
||||
@ -71,4 +75,4 @@ async def chat(
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ from src.services.session_service import (
|
||||
get_sessions_by_agent,
|
||||
get_sessions_by_client,
|
||||
)
|
||||
from src.main import session_service
|
||||
from src.services.service_providers import session_service
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -30,6 +30,7 @@ router = APIRouter(
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
# Session Routes
|
||||
@router.get("/client/{client_id}", response_model=List[Adk_Session])
|
||||
async def get_client_sessions(
|
||||
|
@ -10,9 +10,10 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
@ -1,55 +1,66 @@
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class BaseAPIException(HTTPException):
|
||||
"""Base class for API exceptions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
error_code: str,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, detail={
|
||||
"error": message,
|
||||
"error_code": error_code,
|
||||
"details": details or {}
|
||||
})
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail={
|
||||
"error": message,
|
||||
"error_code": error_code,
|
||||
"details": details or {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class AgentNotFoundError(BaseAPIException):
|
||||
"""Exception when the agent is not found"""
|
||||
|
||||
def __init__(self, agent_id: str):
|
||||
super().__init__(
|
||||
status_code=404,
|
||||
message=f"Agent with ID {agent_id} not found",
|
||||
error_code="AGENT_NOT_FOUND"
|
||||
error_code="AGENT_NOT_FOUND",
|
||||
)
|
||||
|
||||
|
||||
class InvalidParameterError(BaseAPIException):
|
||||
"""Exception for invalid parameters"""
|
||||
|
||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=400,
|
||||
message=message,
|
||||
error_code="INVALID_PARAMETER",
|
||||
details=details
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
class InvalidRequestError(BaseAPIException):
|
||||
"""Exception for invalid requests"""
|
||||
|
||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=400,
|
||||
message=message,
|
||||
error_code="INVALID_REQUEST",
|
||||
details=details
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
class InternalServerError(BaseAPIException):
|
||||
"""Exception for server errors"""
|
||||
|
||||
def __init__(self, message: str = "Server error"):
|
||||
super().__init__(
|
||||
status_code=500,
|
||||
message=message,
|
||||
error_code="INTERNAL_SERVER_ERROR"
|
||||
)
|
||||
status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR"
|
||||
)
|
||||
|
@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
|
||||
async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
||||
"""
|
||||
Extracts and validates the JWT token
|
||||
|
||||
|
||||
Args:
|
||||
token: Token JWT
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Token payload data
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid
|
||||
"""
|
||||
@ -31,86 +32,90 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
||||
detail="Invalid credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
|
||||
|
||||
email: str = payload.get("sub")
|
||||
if email is None:
|
||||
logger.warning("Token without email (sub)")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
exp = payload.get("exp")
|
||||
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
|
||||
logger.warning(f"Token expired for {email}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
except JWTError as e:
|
||||
logger.error(f"Error decoding JWT token: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
async def verify_user_client(
|
||||
payload: dict = Depends(get_jwt_token),
|
||||
db: Session = Depends(get_db),
|
||||
required_client_id: UUID = None
|
||||
required_client_id: UUID = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Verifies if the user is associated with the specified client
|
||||
|
||||
|
||||
Args:
|
||||
payload: Token JWT payload
|
||||
db: Database session
|
||||
required_client_id: Client ID to be verified
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True se verificado com sucesso
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user does not have permission
|
||||
"""
|
||||
# Administrators have access to all clients
|
||||
if payload.get("is_admin", False):
|
||||
return True
|
||||
|
||||
|
||||
# Para não-admins, verificar se o client_id corresponde
|
||||
user_client_id = payload.get("client_id")
|
||||
if not user_client_id:
|
||||
logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}")
|
||||
logger.warning(
|
||||
f"Non-admin user without client_id in token: {payload.get('sub')}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User not associated with a client"
|
||||
detail="User not associated with a client",
|
||||
)
|
||||
|
||||
|
||||
# If no client_id is specified to verify, any client is valid
|
||||
if not required_client_id:
|
||||
return True
|
||||
|
||||
|
||||
# Verify if the user's client_id corresponds to the required_client_id
|
||||
if str(user_client_id) != str(required_client_id):
|
||||
logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}")
|
||||
logger.warning(
|
||||
f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to access resources of this client"
|
||||
detail="Access denied to access resources of this client",
|
||||
)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
||||
"""
|
||||
Verifies if the user is an administrator
|
||||
|
||||
|
||||
Args:
|
||||
payload: Token JWT payload
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the user is an administrator
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user is not an administrator
|
||||
"""
|
||||
@ -118,26 +123,29 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
||||
logger.warning(f"Access denied to admin: User {payload.get('sub')}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. Restricted to administrators."
|
||||
detail="Access denied. Restricted to administrators.",
|
||||
)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]:
|
||||
|
||||
def get_current_user_client_id(
|
||||
payload: dict = Depends(get_jwt_token),
|
||||
) -> Optional[UUID]:
|
||||
"""
|
||||
Gets the ID of the client associated with the current user
|
||||
|
||||
|
||||
Args:
|
||||
payload: Token JWT payload
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[UUID]: Client ID or None if the user is an administrator
|
||||
"""
|
||||
if payload.get("is_admin", False):
|
||||
return None
|
||||
|
||||
|
||||
client_id = payload.get("client_id")
|
||||
if client_id:
|
||||
return UUID(client_id)
|
||||
|
||||
return None
|
||||
|
||||
return None
|
||||
|
55
src/main.py
55
src/main.py
@ -1,35 +1,30 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the root directory to PYTHONPATH
|
||||
root_dir = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_dir))
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.config.database import engine, Base
|
||||
from src.config.settings import settings
|
||||
from src.utils.logger import setup_logger
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from google.adk.memory import InMemoryMemoryService
|
||||
|
||||
# Initialize service instances
|
||||
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
|
||||
artifacts_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
# Necessary for other modules
|
||||
from src.services.service_providers import session_service # noqa: F401
|
||||
from src.services.service_providers import artifacts_service # noqa: F401
|
||||
from src.services.service_providers import memory_service # noqa: F401
|
||||
|
||||
# Import routers after service initialization to avoid circular imports
|
||||
from src.api.auth_routes import router as auth_router
|
||||
from src.api.admin_routes import router as admin_router
|
||||
from src.api.chat_routes import router as chat_router
|
||||
from src.api.session_routes import router as session_router
|
||||
from src.api.agent_routes import router as agent_router
|
||||
from src.api.contact_routes import router as contact_router
|
||||
from src.api.mcp_server_routes import router as mcp_server_router
|
||||
from src.api.tool_routes import router as tool_router
|
||||
from src.api.client_routes import router as client_router
|
||||
import src.api.auth_routes
|
||||
import src.api.admin_routes
|
||||
import src.api.chat_routes
|
||||
import src.api.session_routes
|
||||
import src.api.agent_routes
|
||||
import src.api.contact_routes
|
||||
import src.api.mcp_server_routes
|
||||
import src.api.tool_routes
|
||||
import src.api.client_routes
|
||||
|
||||
# Add the root directory to PYTHONPATH
|
||||
root_dir = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_dir))
|
||||
|
||||
# Configure logger
|
||||
logger = setup_logger(__name__)
|
||||
@ -52,8 +47,7 @@ app.add_middleware(
|
||||
|
||||
# PostgreSQL configuration
|
||||
POSTGRES_CONNECTION_STRING = os.getenv(
|
||||
"POSTGRES_CONNECTION_STRING",
|
||||
"postgresql://postgres:root@localhost:5432/evo_ai"
|
||||
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
|
||||
)
|
||||
|
||||
# Create database tables
|
||||
@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine)
|
||||
|
||||
API_PREFIX = "/api/v1"
|
||||
|
||||
# Define router references
|
||||
auth_router = src.api.auth_routes.router
|
||||
admin_router = src.api.admin_routes.router
|
||||
chat_router = src.api.chat_routes.router
|
||||
session_router = src.api.session_routes.router
|
||||
agent_router = src.api.agent_routes.router
|
||||
contact_router = src.api.contact_routes.router
|
||||
mcp_server_router = src.api.mcp_server_routes.router
|
||||
tool_router = src.api.tool_routes.router
|
||||
client_router = src.api.client_routes.router
|
||||
|
||||
# Include routes
|
||||
app.include_router(auth_router, prefix=API_PREFIX)
|
||||
app.include_router(admin_router, prefix=API_PREFIX)
|
||||
@ -79,5 +84,5 @@ def read_root():
|
||||
"message": "Welcome to Evo AI API",
|
||||
"documentation": "/docs",
|
||||
"version": settings.API_VERSION,
|
||||
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'"
|
||||
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'",
|
||||
}
|
||||
|
@ -1,9 +1,20 @@
|
||||
from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
String,
|
||||
UUID,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Text,
|
||||
CheckConstraint,
|
||||
Boolean,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, backref
|
||||
from src.config.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Client(Base):
|
||||
__tablename__ = "clients"
|
||||
|
||||
@ -13,13 +24,16 @@ class Client(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
email = Column(String, unique=True, index=True, nullable=False)
|
||||
password_hash = Column(String, nullable=False)
|
||||
client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True)
|
||||
client_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
is_active = Column(Boolean, default=False)
|
||||
is_admin = Column(Boolean, default=False)
|
||||
email_verified = Column(Boolean, default=False)
|
||||
@ -29,9 +43,12 @@ class User(Base):
|
||||
password_reset_expiry = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
# Relationship with Client (One-to-One, optional for administrators)
|
||||
client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan"))
|
||||
client = relationship(
|
||||
"Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")
|
||||
)
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
__tablename__ = "contacts"
|
||||
@ -44,6 +61,7 @@ class Contact(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class Agent(Base):
|
||||
__tablename__ = "agents"
|
||||
|
||||
@ -60,21 +78,30 @@ class Agent(Base):
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'),
|
||||
CheckConstraint(
|
||||
"type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
|
||||
),
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Converts the object to a dictionary, converting UUIDs to strings"""
|
||||
result = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if key.startswith('_'):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if isinstance(value, uuid.UUID):
|
||||
result[key] = str(value)
|
||||
elif isinstance(value, dict):
|
||||
result[key] = self._convert_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
|
||||
result[key] = [
|
||||
(
|
||||
self._convert_dict(item)
|
||||
if isinstance(item, dict)
|
||||
else str(item) if isinstance(item, uuid.UUID) else item
|
||||
)
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
@ -88,11 +115,19 @@ class Agent(Base):
|
||||
elif isinstance(value, dict):
|
||||
result[key] = self._convert_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
|
||||
result[key] = [
|
||||
(
|
||||
self._convert_dict(item)
|
||||
if isinstance(item, dict)
|
||||
else str(item) if isinstance(item, uuid.UUID) else item
|
||||
)
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
class MCPServer(Base):
|
||||
__tablename__ = "mcp_servers"
|
||||
|
||||
@ -105,11 +140,14 @@ class MCPServer(Base):
|
||||
type = Column(String, nullable=False, default="official")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'),
|
||||
CheckConstraint(
|
||||
"type IN ('official', 'community')", name="check_mcp_server_type"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
__tablename__ = "tools"
|
||||
|
||||
@ -121,11 +159,12 @@ class Tool(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class Session(Base):
|
||||
__tablename__ = "sessions"
|
||||
# The directive below makes Alembic ignore this table in migrations
|
||||
__table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}}
|
||||
|
||||
__table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
app_name = Column(String)
|
||||
user_id = Column(String)
|
||||
@ -133,11 +172,14 @@ class Session(Base):
|
||||
create_time = Column(DateTime(timezone=True))
|
||||
update_time = Column(DateTime(timezone=True))
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
action = Column(String, nullable=False)
|
||||
resource_type = Column(String, nullable=False)
|
||||
resource_id = Column(String, nullable=True)
|
||||
@ -145,6 +187,6 @@ class AuditLog(Base):
|
||||
ip_address = Column(String, nullable=True)
|
||||
user_agent = Column(String, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
# Relationship with User
|
||||
user = relationship("User", backref="audit_logs")
|
||||
user = relationship("User", backref="audit_logs")
|
||||
|
@ -1,26 +1,38 @@
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from typing import List, Optional, Dict, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""Configuration of a tool"""
|
||||
|
||||
id: UUID
|
||||
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool")
|
||||
envs: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Environment variables of the tool"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration of an MCP server"""
|
||||
|
||||
id: UUID
|
||||
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server")
|
||||
tools: List[str] = Field(default_factory=list, description="List of tools of the server")
|
||||
envs: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Environment variables of the server"
|
||||
)
|
||||
tools: List[str] = Field(
|
||||
default_factory=list, description="List of tools of the server"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPToolParameter(BaseModel):
|
||||
"""Parameter of an HTTP tool"""
|
||||
|
||||
type: str
|
||||
required: bool
|
||||
description: str
|
||||
@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPToolParameters(BaseModel):
|
||||
"""Parameters of an HTTP tool"""
|
||||
|
||||
path_params: Optional[Dict[str, str]] = None
|
||||
query_params: Optional[Dict[str, Union[str, List[str]]]] = None
|
||||
body_params: Optional[Dict[str, HTTPToolParameter]] = None
|
||||
@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPToolErrorHandling(BaseModel):
|
||||
"""Configuration of error handling"""
|
||||
|
||||
timeout: int
|
||||
retry_count: int
|
||||
fallback_response: Dict[str, str]
|
||||
@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPTool(BaseModel):
|
||||
"""Configuration of an HTTP tool"""
|
||||
|
||||
name: str
|
||||
method: str
|
||||
values: Dict[str, str]
|
||||
@ -60,42 +78,72 @@ class HTTPTool(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CustomTools(BaseModel):
|
||||
"""Configuration of custom tools"""
|
||||
http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools")
|
||||
|
||||
http_tools: List[HTTPTool] = Field(
|
||||
default_factory=list, description="List of HTTP tools"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""Configuration for LLM agents"""
|
||||
tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools")
|
||||
custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools")
|
||||
mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers")
|
||||
sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents")
|
||||
|
||||
tools: Optional[List[ToolConfig]] = Field(
|
||||
default=None, description="List of available tools"
|
||||
)
|
||||
custom_tools: Optional[CustomTools] = Field(
|
||||
default=None, description="Custom tools"
|
||||
)
|
||||
mcp_servers: Optional[List[MCPServerConfig]] = Field(
|
||||
default=None, description="List of MCP servers"
|
||||
)
|
||||
sub_agents: Optional[List[UUID]] = Field(
|
||||
default=None, description="List of IDs of sub-agents"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SequentialConfig(BaseModel):
|
||||
"""Configuration for sequential agents"""
|
||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order")
|
||||
|
||||
sub_agents: List[UUID] = Field(
|
||||
..., description="List of IDs of sub-agents in execution order"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ParallelConfig(BaseModel):
|
||||
"""Configuration for parallel agents"""
|
||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution")
|
||||
|
||||
sub_agents: List[UUID] = Field(
|
||||
..., description="List of IDs of sub-agents for parallel execution"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LoopConfig(BaseModel):
|
||||
"""Configuration for loop agents"""
|
||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution")
|
||||
max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations")
|
||||
condition: Optional[str] = Field(default=None, description="Condition to stop the loop")
|
||||
|
||||
sub_agents: List[UUID] = Field(
|
||||
..., description="List of IDs of sub-agents for loop execution"
|
||||
)
|
||||
max_iterations: Optional[int] = Field(
|
||||
default=None, description="Maximum number of iterations"
|
||||
)
|
||||
condition: Optional[str] = Field(
|
||||
default=None, description="Condition to stop the loop"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
from_attributes = True
|
||||
|
@ -3,30 +3,38 @@ from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class AuditLogBase(BaseModel):
|
||||
"""Base schema for audit log"""
|
||||
|
||||
action: str
|
||||
resource_type: str
|
||||
resource_id: Optional[str] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AuditLogCreate(AuditLogBase):
|
||||
"""Schema for creating audit log"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuditLogResponse(AuditLogBase):
|
||||
"""Schema for audit log response"""
|
||||
|
||||
id: UUID
|
||||
user_id: Optional[UUID] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AuditLogFilter(BaseModel):
|
||||
"""Schema for audit log search filters"""
|
||||
|
||||
user_id: Optional[UUID] = None
|
||||
action: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
@ -34,4 +42,4 @@ class AuditLogFilter(BaseModel):
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
skip: Optional[int] = Field(0, ge=0)
|
||||
limit: Optional[int] = Field(100, ge=1, le=1000)
|
||||
limit: Optional[int] = Field(100, ge=1, le=1000)
|
||||
|
@ -1,21 +1,33 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Schema for chat requests"""
|
||||
agent_id: str = Field(..., description="ID of the agent that will process the message")
|
||||
contact_id: str = Field(..., description="ID of the contact that will process the message")
|
||||
|
||||
agent_id: str = Field(
|
||||
..., description="ID of the agent that will process the message"
|
||||
)
|
||||
contact_id: str = Field(
|
||||
..., description="ID of the contact that will process the message"
|
||||
)
|
||||
message: str = Field(..., description="User message")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Schema for chat responses"""
|
||||
|
||||
response: str = Field(..., description="Agent response")
|
||||
status: str = Field(..., description="Operation status")
|
||||
error: Optional[str] = Field(None, description="Error message, if there is one")
|
||||
timestamp: str = Field(..., description="Timestamp of the response")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Schema for error responses"""
|
||||
|
||||
error: str = Field(..., description="Error message")
|
||||
status_code: int = Field(..., description="HTTP status code of the error")
|
||||
details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional error details"
|
||||
)
|
||||
|
@ -4,15 +4,18 @@ from datetime import datetime
|
||||
from uuid import UUID
|
||||
import uuid
|
||||
import re
|
||||
from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig
|
||||
from src.schemas.agent_config import LLMConfig
|
||||
|
||||
|
||||
class ClientBase(BaseModel):
|
||||
name: str
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
|
||||
class ClientCreate(ClientBase):
|
||||
pass
|
||||
|
||||
|
||||
class Client(ClientBase):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
@ -20,14 +23,17 @@ class Client(ClientBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ContactBase(BaseModel):
|
||||
ext_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContactCreate(ContactBase):
|
||||
client_id: UUID
|
||||
|
||||
|
||||
class Contact(ContactBase):
|
||||
id: UUID
|
||||
client_id: UUID
|
||||
@ -35,67 +41,80 @@ class Contact(ContactBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AgentBase(BaseModel):
|
||||
name: str = Field(..., description="Agent name (no spaces or special characters)")
|
||||
description: Optional[str] = Field(None, description="Agent description")
|
||||
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
|
||||
model: Optional[str] = Field(None, description="Agent model (required only for llm type)")
|
||||
api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)")
|
||||
model: Optional[str] = Field(
|
||||
None, description="Agent model (required only for llm type)"
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
None, description="Agent API Key (required only for llm type)"
|
||||
)
|
||||
instruction: Optional[str] = None
|
||||
config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type")
|
||||
config: Union[LLMConfig, Dict[str, Any]] = Field(
|
||||
..., description="Agent configuration based on type"
|
||||
)
|
||||
|
||||
@validator('name')
|
||||
@validator("name")
|
||||
def validate_name(cls, v):
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
|
||||
raise ValueError('Agent name cannot contain spaces or special characters')
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("Agent name cannot contain spaces or special characters")
|
||||
return v
|
||||
|
||||
@validator('type')
|
||||
@validator("type")
|
||||
def validate_type(cls, v):
|
||||
if v not in ['llm', 'sequential', 'parallel', 'loop']:
|
||||
raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop')
|
||||
if v not in ["llm", "sequential", "parallel", "loop"]:
|
||||
raise ValueError(
|
||||
"Invalid agent type. Must be: llm, sequential, parallel or loop"
|
||||
)
|
||||
return v
|
||||
|
||||
@validator('model')
|
||||
@validator("model")
|
||||
def validate_model(cls, v, values):
|
||||
if 'type' in values and values['type'] == 'llm' and not v:
|
||||
raise ValueError('Model is required for llm type agents')
|
||||
if "type" in values and values["type"] == "llm" and not v:
|
||||
raise ValueError("Model is required for llm type agents")
|
||||
return v
|
||||
|
||||
@validator('api_key')
|
||||
@validator("api_key")
|
||||
def validate_api_key(cls, v, values):
|
||||
if 'type' in values and values['type'] == 'llm' and not v:
|
||||
raise ValueError('API Key is required for llm type agents')
|
||||
if "type" in values and values["type"] == "llm" and not v:
|
||||
raise ValueError("API Key is required for llm type agents")
|
||||
return v
|
||||
|
||||
@validator('config')
|
||||
@validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
if 'type' not in values:
|
||||
if "type" not in values:
|
||||
return v
|
||||
|
||||
if values['type'] == 'llm':
|
||||
|
||||
if values["type"] == "llm":
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
# Convert the dictionary to LLMConfig
|
||||
v = LLMConfig(**v)
|
||||
except Exception as e:
|
||||
raise ValueError(f'Invalid LLM configuration for agent: {str(e)}')
|
||||
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
|
||||
elif not isinstance(v, LLMConfig):
|
||||
raise ValueError('Invalid LLM configuration for agent')
|
||||
elif values['type'] in ['sequential', 'parallel', 'loop']:
|
||||
raise ValueError("Invalid LLM configuration for agent")
|
||||
elif values["type"] in ["sequential", "parallel", "loop"]:
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError(f'Invalid configuration for agent {values["type"]}')
|
||||
if 'sub_agents' not in v:
|
||||
if "sub_agents" not in v:
|
||||
raise ValueError(f'Agent {values["type"]} must have sub_agents')
|
||||
if not isinstance(v['sub_agents'], list):
|
||||
raise ValueError('sub_agents must be a list')
|
||||
if not v['sub_agents']:
|
||||
raise ValueError(f'Agent {values["type"]} must have at least one sub-agent')
|
||||
if not isinstance(v["sub_agents"], list):
|
||||
raise ValueError("sub_agents must be a list")
|
||||
if not v["sub_agents"]:
|
||||
raise ValueError(
|
||||
f'Agent {values["type"]} must have at least one sub-agent'
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class AgentCreate(AgentBase):
|
||||
client_id: UUID
|
||||
|
||||
|
||||
class Agent(AgentBase):
|
||||
id: UUID
|
||||
client_id: UUID
|
||||
@ -105,6 +124,7 @@ class Agent(AgentBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MCPServerBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
@ -113,9 +133,11 @@ class MCPServerBase(BaseModel):
|
||||
tools: List[str] = Field(default_factory=list)
|
||||
type: str = Field(default="official")
|
||||
|
||||
|
||||
class MCPServerCreate(MCPServerBase):
|
||||
pass
|
||||
|
||||
|
||||
class MCPServer(MCPServerBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
@ -124,19 +146,22 @@ class MCPServer(MCPServerBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ToolBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
config_json: Dict[str, Any] = Field(default_factory=dict)
|
||||
environments: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolCreate(ToolBase):
|
||||
pass
|
||||
|
||||
|
||||
class Tool(ToolBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
from_attributes = True
|
||||
|
@ -1,23 +1,28 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
name: str # For client creation
|
||||
|
||||
|
||||
class AdminUserCreate(UserBase):
|
||||
password: str
|
||||
name: str
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
client_id: Optional[UUID] = None
|
||||
@ -25,26 +30,31 @@ class UserResponse(UserBase):
|
||||
email_verified: bool
|
||||
is_admin: bool
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
sub: str # user email
|
||||
exp: datetime
|
||||
is_admin: bool
|
||||
client_id: Optional[UUID] = None
|
||||
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
|
||||
class ForgotPassword(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
message: str
|
||||
|
@ -1 +1 @@
|
||||
from .agent_runner import run_agent
|
||||
from .agent_runner import run_agent
|
||||
|
@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.models import LlmResponse, LlmRequest
|
||||
from google.adk.tools import load_memory
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@ -83,7 +82,7 @@ def before_model_callback(
|
||||
llm_request.config.system_instruction = modified_text
|
||||
|
||||
logger.debug(
|
||||
f"📝 System instruction updated with search results and history"
|
||||
"📝 System instruction updated with search results and history"
|
||||
)
|
||||
else:
|
||||
logger.warning("⚠️ No results found in the search")
|
||||
@ -180,11 +179,13 @@ class AgentBuilder:
|
||||
mcp_tools = []
|
||||
mcp_exit_stack = None
|
||||
if agent.config.get("mcp_servers"):
|
||||
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db)
|
||||
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
|
||||
agent.config, self.db
|
||||
)
|
||||
|
||||
# Combine all tools
|
||||
all_tools = custom_tools + mcp_tools
|
||||
|
||||
|
||||
now = datetime.now()
|
||||
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
||||
current_day_of_week = now.strftime("%A")
|
||||
@ -201,10 +202,13 @@ class AgentBuilder:
|
||||
|
||||
# Check if load_memory is enabled
|
||||
# before_model_callback_func = None
|
||||
if agent.config.get("load_memory") == True:
|
||||
if agent.config.get("load_memory"):
|
||||
all_tools.append(load_memory)
|
||||
# before_model_callback_func = before_model_callback
|
||||
formatted_prompt = formatted_prompt + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
|
||||
formatted_prompt = (
|
||||
formatted_prompt
|
||||
+ "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
|
||||
)
|
||||
|
||||
return (
|
||||
LlmAgent(
|
||||
|
@ -22,9 +22,7 @@ async def run_agent(
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting execution of agent {agent_id} for contact {contact_id}"
|
||||
)
|
||||
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
|
||||
logger.info(f"Received message: {message}")
|
||||
|
||||
get_root_agent = get_agent(db, agent_id)
|
||||
@ -77,15 +75,15 @@ async def run_agent(
|
||||
if event.is_final_response() and event.content and event.content.parts:
|
||||
final_response_text = event.content.parts[0].text
|
||||
logger.info(f"Final response received: {final_response_text}")
|
||||
|
||||
|
||||
completed_session = session_service.get_session(
|
||||
app_name=agent_id,
|
||||
user_id=contact_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
memory_service.add_session_to_memory(completed_session)
|
||||
|
||||
|
||||
finally:
|
||||
# Ensure the exit_stack is closed correctly
|
||||
if exit_stack:
|
||||
|
@ -216,9 +216,7 @@ async def update_agent(
|
||||
return agent
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error updating agent: {str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")
|
||||
|
||||
|
||||
def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:
|
||||
|
@ -1,15 +1,15 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.models.models import AuditLog, User
|
||||
from src.models.models import AuditLog
|
||||
from datetime import datetime
|
||||
from fastapi import Request
|
||||
from typing import Optional, Dict, Any, List
|
||||
import uuid
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_audit_log(
|
||||
db: Session,
|
||||
user_id: Optional[uuid.UUID],
|
||||
@ -17,11 +17,11 @@ def create_audit_log(
|
||||
resource_type: str,
|
||||
resource_id: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
request: Optional[Request] = None
|
||||
request: Optional[Request] = None,
|
||||
) -> Optional[AuditLog]:
|
||||
"""
|
||||
Create a new audit log
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID that performed the action (or None if anonymous)
|
||||
@ -30,25 +30,25 @@ def create_audit_log(
|
||||
resource_id: Resource ID (optional)
|
||||
details: Additional details of the action (optional)
|
||||
request: FastAPI Request object (optional, to get IP and User-Agent)
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[AuditLog]: Created audit log or None in case of error
|
||||
"""
|
||||
try:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
|
||||
if request:
|
||||
ip_address = request.client.host if hasattr(request, 'client') else None
|
||||
ip_address = request.client.host if hasattr(request, "client") else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
|
||||
# Convert details to serializable format
|
||||
if details:
|
||||
# Convert UUIDs to strings
|
||||
for key, value in details.items():
|
||||
if isinstance(value, uuid.UUID):
|
||||
details[key] = str(value)
|
||||
|
||||
|
||||
audit_log = AuditLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
@ -56,20 +56,20 @@ def create_audit_log(
|
||||
resource_id=str(resource_id) if resource_id else None,
|
||||
details=details,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
|
||||
db.add(audit_log)
|
||||
db.commit()
|
||||
db.refresh(audit_log)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Audit log created: {action} in {resource_type}" +
|
||||
(f" (ID: {resource_id})" if resource_id else "")
|
||||
f"Audit log created: {action} in {resource_type}"
|
||||
+ (f" (ID: {resource_id})" if resource_id else "")
|
||||
)
|
||||
|
||||
|
||||
return audit_log
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating audit log: {str(e)}")
|
||||
@ -78,6 +78,7 @@ def create_audit_log(
|
||||
logger.error(f"Unexpected error creating audit log: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_audit_logs(
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
@ -87,11 +88,11 @@ def get_audit_logs(
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> List[AuditLog]:
|
||||
"""
|
||||
Get audit logs with optional filters
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
@ -102,35 +103,35 @@ def get_audit_logs(
|
||||
resource_id: Filter by resource ID
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
|
||||
Returns:
|
||||
List[AuditLog]: List of audit logs
|
||||
"""
|
||||
query = db.query(AuditLog)
|
||||
|
||||
|
||||
# Apply filters, if provided
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
|
||||
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action)
|
||||
|
||||
|
||||
if resource_type:
|
||||
query = query.filter(AuditLog.resource_type == resource_type)
|
||||
|
||||
|
||||
if resource_id:
|
||||
query = query.filter(AuditLog.resource_id == resource_id)
|
||||
|
||||
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.created_at >= start_date)
|
||||
|
||||
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.created_at <= end_date)
|
||||
|
||||
|
||||
# Order by creation date (most recent first)
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
return query.all()
|
||||
|
@ -16,17 +16,20 @@ logger = logging.getLogger(__name__)
|
||||
# Define OAuth2 authentication scheme with password flow
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current user from the JWT token
|
||||
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
db: Database session
|
||||
|
||||
|
||||
Returns:
|
||||
User: Current user
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or the user is not found
|
||||
"""
|
||||
@ -35,103 +38,108 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
|
||||
detail="Invalid credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# Decode the token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
|
||||
|
||||
# Extract token data
|
||||
email: str = payload.get("sub")
|
||||
if email is None:
|
||||
logger.warning("Token without email (sub)")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# Check if the token has expired
|
||||
exp = payload.get("exp")
|
||||
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
|
||||
logger.warning(f"Token expired for {email}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# Create TokenData object
|
||||
token_data = TokenData(
|
||||
sub=email,
|
||||
exp=datetime.fromtimestamp(exp),
|
||||
is_admin=payload.get("is_admin", False),
|
||||
client_id=payload.get("client_id")
|
||||
client_id=payload.get("client_id"),
|
||||
)
|
||||
|
||||
|
||||
except JWTError as e:
|
||||
logger.error(f"Error decoding JWT token: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# Search for user in the database
|
||||
user = get_user_by_email(db, email=token_data.sub)
|
||||
if user is None:
|
||||
logger.warning(f"User not found for email: {token_data.sub}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
if not user.is_active:
|
||||
logger.warning(f"Attempt to access inactive user: {user.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
|
||||
return user
|
||||
|
||||
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is active
|
||||
|
||||
|
||||
Args:
|
||||
current_user: Current user
|
||||
|
||||
|
||||
Returns:
|
||||
User: Current user if active
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user is not active
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
logger.warning(f"Attempt to access inactive user: {current_user.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
|
||||
async def get_current_admin_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is an administrator
|
||||
|
||||
|
||||
Args:
|
||||
current_user: Current user
|
||||
|
||||
|
||||
Returns:
|
||||
User: Current user if administrator
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user is not an administrator
|
||||
"""
|
||||
if not current_user.is_admin:
|
||||
logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}")
|
||||
logger.warning(
|
||||
f"Attempt to access admin by non-admin user: {current_user.email}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. Restricted to administrators."
|
||||
detail="Access denied. Restricted to administrators.",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def create_access_token(user: User) -> str:
|
||||
"""
|
||||
Create a JWT access token for the user
|
||||
|
||||
|
||||
Args:
|
||||
user: User for which to create the token
|
||||
|
||||
|
||||
Returns:
|
||||
str: JWT token
|
||||
"""
|
||||
@ -140,10 +148,10 @@ def create_access_token(user: User) -> str:
|
||||
"sub": user.email,
|
||||
"is_admin": user.is_admin,
|
||||
}
|
||||
|
||||
|
||||
# Include client_id only if not administrator and client_id is set
|
||||
if not user.is_admin and user.client_id:
|
||||
token_data["client_id"] = str(user.client_id)
|
||||
|
||||
|
||||
# Create token
|
||||
return create_jwt_token(token_data)
|
||||
return create_jwt_token(token_data)
|
||||
|
@ -11,6 +11,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
|
||||
"""Search for a client by ID"""
|
||||
try:
|
||||
@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
|
||||
logger.error(f"Error searching for client {client_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for client"
|
||||
detail="Error searching for client",
|
||||
)
|
||||
|
||||
|
||||
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
|
||||
"""Search for all clients with pagination"""
|
||||
try:
|
||||
@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
|
||||
logger.error(f"Error searching for clients: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for clients"
|
||||
detail="Error searching for clients",
|
||||
)
|
||||
|
||||
|
||||
def create_client(db: Session, client: ClientCreate) -> Client:
|
||||
"""Create a new client"""
|
||||
try:
|
||||
@ -51,19 +54,22 @@ def create_client(db: Session, client: ClientCreate) -> Client:
|
||||
logger.error(f"Error creating client: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error creating client"
|
||||
detail="Error creating client",
|
||||
)
|
||||
|
||||
def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]:
|
||||
|
||||
def update_client(
|
||||
db: Session, client_id: uuid.UUID, client: ClientCreate
|
||||
) -> Optional[Client]:
|
||||
"""Updates an existing client"""
|
||||
try:
|
||||
db_client = get_client(db, client_id)
|
||||
if not db_client:
|
||||
return None
|
||||
|
||||
|
||||
for key, value in client.model_dump().items():
|
||||
setattr(db_client, key, value)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_client)
|
||||
logger.info(f"Client updated successfully: {client_id}")
|
||||
@ -73,16 +79,17 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op
|
||||
logger.error(f"Error updating client {client_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error updating client"
|
||||
detail="Error updating client",
|
||||
)
|
||||
|
||||
|
||||
def delete_client(db: Session, client_id: uuid.UUID) -> bool:
|
||||
"""Removes a client"""
|
||||
try:
|
||||
db_client = get_client(db, client_id)
|
||||
if not db_client:
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_client)
|
||||
db.commit()
|
||||
logger.info(f"Client removed successfully: {client_id}")
|
||||
@ -92,18 +99,21 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:
|
||||
logger.error(f"Error removing client {client_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error removing client"
|
||||
detail="Error removing client",
|
||||
)
|
||||
|
||||
def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]:
|
||||
|
||||
def create_client_with_user(
|
||||
db: Session, client_data: ClientCreate, user_data: UserCreate
|
||||
) -> Tuple[Optional[Client], str]:
|
||||
"""
|
||||
Creates a new client with an associated user
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_data: Client data to be created
|
||||
user_data: User data to be created
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message
|
||||
"""
|
||||
@ -112,27 +122,27 @@ def create_client_with_user(db: Session, client_data: ClientCreate, user_data: U
|
||||
client = Client(**client_data.model_dump())
|
||||
db.add(client)
|
||||
db.flush() # Get client ID without committing the transaction
|
||||
|
||||
|
||||
# Use client ID to create the associated user
|
||||
user, message = create_user(db, user_data, is_admin=False, client_id=client.id)
|
||||
|
||||
|
||||
if not user:
|
||||
# If there was an error creating the user, rollback
|
||||
db.rollback()
|
||||
logger.error(f"Error creating user for client: {message}")
|
||||
return None, f"Error creating user: {message}"
|
||||
|
||||
|
||||
# If everything went well, commit the transaction
|
||||
db.commit()
|
||||
logger.info(f"Client and user created successfully: {client.id}")
|
||||
return client, "Client and user created successfully"
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating client with user: {str(e)}")
|
||||
return None, f"Error creating client with user: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating client with user: {str(e)}")
|
||||
return None, f"Unexpected error: {str(e)}"
|
||||
return None, f"Unexpected error: {str(e)}"
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
|
||||
"""Search for a contact by ID"""
|
||||
try:
|
||||
@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
|
||||
logger.error(f"Error searching for contact {contact_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for contact"
|
||||
detail="Error searching for contact",
|
||||
)
|
||||
|
||||
def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]:
|
||||
|
||||
def get_contacts_by_client(
|
||||
db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
|
||||
) -> List[Contact]:
|
||||
"""Search for contacts of a client with pagination"""
|
||||
try:
|
||||
return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all()
|
||||
return (
|
||||
db.query(Contact)
|
||||
.filter(Contact.client_id == client_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for contacts"
|
||||
detail="Error searching for contacts",
|
||||
)
|
||||
|
||||
|
||||
def create_contact(db: Session, contact: ContactCreate) -> Contact:
|
||||
"""Create a new contact"""
|
||||
try:
|
||||
@ -49,19 +60,22 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:
|
||||
logger.error(f"Error creating contact: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error creating contact"
|
||||
detail="Error creating contact",
|
||||
)
|
||||
|
||||
def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]:
|
||||
|
||||
def update_contact(
|
||||
db: Session, contact_id: uuid.UUID, contact: ContactCreate
|
||||
) -> Optional[Contact]:
|
||||
"""Update an existing contact"""
|
||||
try:
|
||||
db_contact = get_contact(db, contact_id)
|
||||
if not db_contact:
|
||||
return None
|
||||
|
||||
|
||||
for key, value in contact.model_dump().items():
|
||||
setattr(db_contact, key, value)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_contact)
|
||||
logger.info(f"Contact updated successfully: {contact_id}")
|
||||
@ -71,16 +85,17 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -
|
||||
logger.error(f"Error updating contact {contact_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error updating contact"
|
||||
detail="Error updating contact",
|
||||
)
|
||||
|
||||
|
||||
def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
|
||||
"""Remove a contact"""
|
||||
try:
|
||||
db_contact = get_contact(db, contact_id)
|
||||
if not db_contact:
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_contact)
|
||||
db.commit()
|
||||
logger.info(f"Contact removed successfully: {contact_id}")
|
||||
@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
|
||||
logger.error(f"Error removing contact {contact_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error removing contact"
|
||||
)
|
||||
detail="Error removing contact",
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from src.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class CustomToolBuilder:
|
||||
def __init__(self):
|
||||
self.tools = []
|
||||
@ -53,7 +54,9 @@ class CustomToolBuilder:
|
||||
|
||||
# Adds default values to query params if they are not present
|
||||
for param, value in values.items():
|
||||
if param not in query_params and param not in parameters.get("path_params", {}):
|
||||
if param not in query_params and param not in parameters.get(
|
||||
"path_params", {}
|
||||
):
|
||||
query_params[param] = value
|
||||
|
||||
# Processa body parameters
|
||||
@ -64,7 +67,11 @@ class CustomToolBuilder:
|
||||
|
||||
# Adds default values to body if they are not present
|
||||
for param, value in values.items():
|
||||
if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}):
|
||||
if (
|
||||
param not in body_data
|
||||
and param not in query_params
|
||||
and param not in parameters.get("path_params", {})
|
||||
):
|
||||
body_data[param] = value
|
||||
|
||||
# Makes the HTTP request
|
||||
@ -74,7 +81,7 @@ class CustomToolBuilder:
|
||||
headers=processed_headers,
|
||||
params=query_params,
|
||||
json=body_data if body_data else None,
|
||||
timeout=error_handling.get("timeout", 30)
|
||||
timeout=error_handling.get("timeout", 30),
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
@ -87,30 +94,34 @@ class CustomToolBuilder:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {name}: {str(e)}")
|
||||
return json.dumps(error_handling.get("fallback_response", {
|
||||
"error": "tool_execution_error",
|
||||
"message": str(e)
|
||||
}))
|
||||
return json.dumps(
|
||||
error_handling.get(
|
||||
"fallback_response",
|
||||
{"error": "tool_execution_error", "message": str(e)},
|
||||
)
|
||||
)
|
||||
|
||||
# Adds dynamic docstring based on the configuration
|
||||
param_docs = []
|
||||
|
||||
|
||||
# Adds path parameters
|
||||
for param, value in parameters.get("path_params", {}).items():
|
||||
param_docs.append(f"{param}: {value}")
|
||||
|
||||
|
||||
# Adds query parameters
|
||||
for param, value in parameters.get("query_params", {}).items():
|
||||
if isinstance(value, list):
|
||||
param_docs.append(f"{param}: List[{', '.join(value)}]")
|
||||
else:
|
||||
param_docs.append(f"{param}: {value}")
|
||||
|
||||
|
||||
# Adds body parameters
|
||||
for param, param_config in parameters.get("body_params", {}).items():
|
||||
required = "Required" if param_config.get("required", False) else "Optional"
|
||||
param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}")
|
||||
|
||||
param_docs.append(
|
||||
f"{param} ({param_config['type']}, {required}): {param_config['description']}"
|
||||
)
|
||||
|
||||
# Adds default values
|
||||
if values:
|
||||
param_docs.append("\nDefault values:")
|
||||
@ -119,10 +130,10 @@ class CustomToolBuilder:
|
||||
|
||||
http_tool.__doc__ = f"""
|
||||
{description}
|
||||
|
||||
|
||||
Parameters:
|
||||
{chr(10).join(param_docs)}
|
||||
|
||||
|
||||
Returns:
|
||||
String containing the response in JSON format
|
||||
"""
|
||||
@ -140,4 +151,4 @@ class CustomToolBuilder:
|
||||
for http_tool_config in tools_config.get("http_tools", []):
|
||||
self.tools.append(self._create_http_tool(http_tool_config))
|
||||
|
||||
return self.tools
|
||||
return self.tools
|
||||
|
@ -16,17 +16,18 @@ os.makedirs(templates_dir, exist_ok=True)
|
||||
# Configure Jinja2 with the templates directory
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(templates_dir),
|
||||
autoescape=select_autoescape(['html', 'xml'])
|
||||
autoescape=select_autoescape(["html", "xml"]),
|
||||
)
|
||||
|
||||
|
||||
def _render_template(template_name: str, context: dict) -> str:
|
||||
"""
|
||||
Render a template with the provided data
|
||||
|
||||
|
||||
Args:
|
||||
template_name: Template file name
|
||||
context: Data to render in the template
|
||||
|
||||
|
||||
Returns:
|
||||
str: Rendered HTML
|
||||
"""
|
||||
@ -37,14 +38,15 @@ def _render_template(template_name: str, context: dict) -> str:
|
||||
logger.error(f"Error rendering template '{template_name}': {str(e)}")
|
||||
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"
|
||||
|
||||
|
||||
def send_verification_email(email: str, token: str) -> bool:
|
||||
"""
|
||||
Send a verification email to the user
|
||||
|
||||
|
||||
Args:
|
||||
email: Recipient's email
|
||||
token: Email verification token
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the email was sent successfully, False otherwise
|
||||
"""
|
||||
@ -53,39 +55,47 @@ def send_verification_email(email: str, token: str) -> bool:
|
||||
from_email = Email(settings.EMAIL_FROM)
|
||||
to_email = To(email)
|
||||
subject = "Email Verification - Evo AI"
|
||||
|
||||
|
||||
verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"
|
||||
|
||||
html_content = _render_template('verification_email', {
|
||||
'verification_link': verification_link,
|
||||
'user_name': email.split('@')[0], # Use part of the email as temporary name
|
||||
'current_year': datetime.now().year
|
||||
})
|
||||
|
||||
|
||||
html_content = _render_template(
|
||||
"verification_email",
|
||||
{
|
||||
"verification_link": verification_link,
|
||||
"user_name": email.split("@")[
|
||||
0
|
||||
], # Use part of the email as temporary name
|
||||
"current_year": datetime.now().year,
|
||||
},
|
||||
)
|
||||
|
||||
content = Content("text/html", html_content)
|
||||
|
||||
|
||||
mail = Mail(from_email, to_email, subject, content)
|
||||
response = sg.client.mail.send.post(request_body=mail.get())
|
||||
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
logger.info(f"Verification email sent to {email}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}")
|
||||
logger.error(
|
||||
f"Failed to send verification email to {email}. Status: {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending verification email to {email}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def send_password_reset_email(email: str, token: str) -> bool:
|
||||
"""
|
||||
Send a password reset email to the user
|
||||
|
||||
|
||||
Args:
|
||||
email: Recipient's email
|
||||
token: Password reset token
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the email was sent successfully, False otherwise
|
||||
"""
|
||||
@ -94,39 +104,47 @@ def send_password_reset_email(email: str, token: str) -> bool:
|
||||
from_email = Email(settings.EMAIL_FROM)
|
||||
to_email = To(email)
|
||||
subject = "Password Reset - Evo AI"
|
||||
|
||||
|
||||
reset_link = f"{settings.APP_URL}/reset-password?token={token}"
|
||||
|
||||
html_content = _render_template('password_reset', {
|
||||
'reset_link': reset_link,
|
||||
'user_name': email.split('@')[0], # Use part of the email as temporary name
|
||||
'current_year': datetime.now().year
|
||||
})
|
||||
|
||||
|
||||
html_content = _render_template(
|
||||
"password_reset",
|
||||
{
|
||||
"reset_link": reset_link,
|
||||
"user_name": email.split("@")[
|
||||
0
|
||||
], # Use part of the email as temporary name
|
||||
"current_year": datetime.now().year,
|
||||
},
|
||||
)
|
||||
|
||||
content = Content("text/html", html_content)
|
||||
|
||||
|
||||
mail = Mail(from_email, to_email, subject, content)
|
||||
response = sg.client.mail.send.post(request_body=mail.get())
|
||||
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
logger.info(f"Password reset email sent to {email}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}")
|
||||
logger.error(
|
||||
f"Failed to send password reset email to {email}. Status: {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending password reset email to {email}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def send_welcome_email(email: str, user_name: str = None) -> bool:
|
||||
"""
|
||||
Send a welcome email to the user after verification
|
||||
|
||||
|
||||
Args:
|
||||
email: Recipient's email
|
||||
user_name: User's name (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the email was sent successfully, False otherwise
|
||||
"""
|
||||
@ -135,41 +153,49 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
|
||||
from_email = Email(settings.EMAIL_FROM)
|
||||
to_email = To(email)
|
||||
subject = "Welcome to Evo AI"
|
||||
|
||||
|
||||
dashboard_link = f"{settings.APP_URL}/dashboard"
|
||||
|
||||
html_content = _render_template('welcome_email', {
|
||||
'dashboard_link': dashboard_link,
|
||||
'user_name': user_name or email.split('@')[0],
|
||||
'current_year': datetime.now().year
|
||||
})
|
||||
|
||||
|
||||
html_content = _render_template(
|
||||
"welcome_email",
|
||||
{
|
||||
"dashboard_link": dashboard_link,
|
||||
"user_name": user_name or email.split("@")[0],
|
||||
"current_year": datetime.now().year,
|
||||
},
|
||||
)
|
||||
|
||||
content = Content("text/html", html_content)
|
||||
|
||||
|
||||
mail = Mail(from_email, to_email, subject, content)
|
||||
response = sg.client.mail.send.post(request_body=mail.get())
|
||||
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
logger.info(f"Welcome email sent to {email}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}")
|
||||
logger.error(
|
||||
f"Failed to send welcome email to {email}. Status: {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending welcome email to {email}: {str(e)}")
|
||||
return False
|
||||
|
||||
def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool:
|
||||
|
||||
def send_account_locked_email(
|
||||
email: str, reset_token: str, failed_attempts: int, time_period: str
|
||||
) -> bool:
|
||||
"""
|
||||
Send an email informing that the account has been locked after login attempts
|
||||
|
||||
|
||||
Args:
|
||||
email: Recipient's email
|
||||
reset_token: Token to reset the password
|
||||
failed_attempts: Number of failed attempts
|
||||
time_period: Time period of the attempts
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the email was sent successfully, False otherwise
|
||||
"""
|
||||
@ -178,29 +204,34 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
|
||||
from_email = Email(settings.EMAIL_FROM)
|
||||
to_email = To(email)
|
||||
subject = "Security Alert - Account Locked"
|
||||
|
||||
|
||||
reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"
|
||||
|
||||
html_content = _render_template('account_locked', {
|
||||
'reset_link': reset_link,
|
||||
'user_name': email.split('@')[0],
|
||||
'failed_attempts': failed_attempts,
|
||||
'time_period': time_period,
|
||||
'current_year': datetime.now().year
|
||||
})
|
||||
|
||||
|
||||
html_content = _render_template(
|
||||
"account_locked",
|
||||
{
|
||||
"reset_link": reset_link,
|
||||
"user_name": email.split("@")[0],
|
||||
"failed_attempts": failed_attempts,
|
||||
"time_period": time_period,
|
||||
"current_year": datetime.now().year,
|
||||
},
|
||||
)
|
||||
|
||||
content = Content("text/html", html_content)
|
||||
|
||||
|
||||
mail = Mail(from_email, to_email, subject, content)
|
||||
response = sg.client.mail.send.post(request_body=mail.get())
|
||||
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
logger.info(f"Account locked email sent to {email}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}")
|
||||
logger.error(
|
||||
f"Failed to send account locked email to {email}. Status: {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending account locked email to {email}: {str(e)}")
|
||||
return False
|
||||
return False
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
|
||||
"""Search for an MCP server by ID"""
|
||||
try:
|
||||
@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
|
||||
logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for MCP server"
|
||||
detail="Error searching for MCP server",
|
||||
)
|
||||
|
||||
|
||||
def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
|
||||
"""Search for all MCP servers with pagination"""
|
||||
try:
|
||||
@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer
|
||||
logger.error(f"Error searching for MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for MCP servers"
|
||||
detail="Error searching for MCP servers",
|
||||
)
|
||||
|
||||
|
||||
def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
|
||||
"""Create a new MCP server"""
|
||||
try:
|
||||
@ -49,19 +52,22 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
|
||||
logger.error(f"Error creating MCP server: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error creating MCP server"
|
||||
detail="Error creating MCP server",
|
||||
)
|
||||
|
||||
def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]:
|
||||
|
||||
def update_mcp_server(
|
||||
db: Session, server_id: uuid.UUID, server: MCPServerCreate
|
||||
) -> Optional[MCPServer]:
|
||||
"""Update an existing MCP server"""
|
||||
try:
|
||||
db_server = get_mcp_server(db, server_id)
|
||||
if not db_server:
|
||||
return None
|
||||
|
||||
|
||||
for key, value in server.model_dump().items():
|
||||
setattr(db_server, key, value)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_server)
|
||||
logger.info(f"MCP server updated successfully: {server_id}")
|
||||
@ -71,16 +77,17 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate
|
||||
logger.error(f"Error updating MCP server {server_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error updating MCP server"
|
||||
detail="Error updating MCP server",
|
||||
)
|
||||
|
||||
|
||||
def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
|
||||
"""Remove an MCP server"""
|
||||
try:
|
||||
db_server = get_mcp_server(db, server_id)
|
||||
if not db_server:
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_server)
|
||||
db.commit()
|
||||
logger.info(f"MCP server removed successfully: {server_id}")
|
||||
@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
|
||||
logger.error(f"Error removing MCP server {server_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error removing MCP server"
|
||||
)
|
||||
detail="Error removing MCP server",
|
||||
)
|
||||
|
@ -12,26 +12,28 @@ from sqlalchemy.orm import Session
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class MCPService:
|
||||
def __init__(self):
|
||||
self.tools = []
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]:
|
||||
async def _connect_to_mcp_server(
|
||||
self, server_config: Dict[str, Any]
|
||||
) -> Tuple[List[Any], Optional[AsyncExitStack]]:
|
||||
"""Connect to a specific MCP server and return its tools."""
|
||||
try:
|
||||
# Determines the type of server (local or remote)
|
||||
if "url" in server_config:
|
||||
# Remote server (SSE)
|
||||
connection_params = SseServerParams(
|
||||
url=server_config["url"],
|
||||
headers=server_config.get("headers", {})
|
||||
url=server_config["url"], headers=server_config.get("headers", {})
|
||||
)
|
||||
else:
|
||||
# Local server (Stdio)
|
||||
command = server_config.get("command", "npx")
|
||||
args = server_config.get("args", [])
|
||||
|
||||
|
||||
# Adds environment variables if specified
|
||||
env = server_config.get("env", {})
|
||||
if env:
|
||||
@ -39,9 +41,7 @@ class MCPService:
|
||||
os.environ[key] = value
|
||||
|
||||
connection_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=args,
|
||||
env=env
|
||||
command=command, args=args, env=env
|
||||
)
|
||||
|
||||
tools, exit_stack = await MCPToolset.from_server(
|
||||
@ -73,8 +73,10 @@ class MCPService:
|
||||
logger.warning(f"Removed {removed_count} incompatible tools.")
|
||||
|
||||
return filtered_tools
|
||||
|
||||
def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]:
|
||||
|
||||
def _filter_tools_by_agent(
|
||||
self, tools: List[Any], agent_tools: List[str]
|
||||
) -> List[Any]:
|
||||
"""Filters tools compatible with the agent."""
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
@ -83,7 +85,9 @@ class MCPService:
|
||||
filtered_tools.append(tool)
|
||||
return filtered_tools
|
||||
|
||||
async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]:
|
||||
async def build_tools(
|
||||
self, mcp_config: Dict[str, Any], db: Session
|
||||
) -> Tuple[List[Any], AsyncExitStack]:
|
||||
"""Builds a list of tools from multiple MCP servers."""
|
||||
self.tools = []
|
||||
self.exit_stack = AsyncExitStack()
|
||||
@ -92,23 +96,25 @@ class MCPService:
|
||||
for server in mcp_config.get("mcp_servers", []):
|
||||
try:
|
||||
# Search for the MCP server in the database
|
||||
mcp_server = get_mcp_server(db, server['id'])
|
||||
mcp_server = get_mcp_server(db, server["id"])
|
||||
if not mcp_server:
|
||||
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
|
||||
continue
|
||||
|
||||
# Prepares the server configuration
|
||||
server_config = mcp_server.config_json.copy()
|
||||
|
||||
|
||||
# Replaces the environment variables in the config_json
|
||||
if 'env' in server_config:
|
||||
for key, value in server_config['env'].items():
|
||||
if value.startswith('env@@'):
|
||||
env_key = value.replace('env@@', '')
|
||||
if env_key in server.get('envs', {}):
|
||||
server_config['env'][key] = server['envs'][env_key]
|
||||
if "env" in server_config:
|
||||
for key, value in server_config["env"].items():
|
||||
if value.startswith("env@@"):
|
||||
env_key = value.replace("env@@", "")
|
||||
if env_key in server.get("envs", {}):
|
||||
server_config["env"][key] = server["envs"][env_key]
|
||||
else:
|
||||
logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}")
|
||||
logger.warning(
|
||||
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Connecting to MCP server: {mcp_server.name}")
|
||||
@ -117,22 +123,30 @@ class MCPService:
|
||||
if tools and exit_stack:
|
||||
# Filters incompatible tools
|
||||
filtered_tools = self._filter_incompatible_tools(tools)
|
||||
|
||||
|
||||
# Filters tools compatible with the agent
|
||||
agent_tools = server.get('tools', [])
|
||||
filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools)
|
||||
agent_tools = server.get("tools", [])
|
||||
filtered_tools = self._filter_tools_by_agent(
|
||||
filtered_tools, agent_tools
|
||||
)
|
||||
self.tools.extend(filtered_tools)
|
||||
|
||||
|
||||
# Registers the exit_stack with the AsyncExitStack
|
||||
await self.exit_stack.enter_async_context(exit_stack)
|
||||
logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.")
|
||||
logger.info(
|
||||
f"Connected successfully. Added {len(filtered_tools)} tools."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Failed to connect or no tools available for {mcp_server.name}")
|
||||
logger.warning(
|
||||
f"Failed to connect or no tools available for {mcp_server.name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.")
|
||||
logger.info(
|
||||
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
||||
)
|
||||
|
||||
return self.tools, self.exit_stack
|
||||
return self.tools, self.exit_stack
|
||||
|
9
src/services/service_providers.py
Normal file
9
src/services/service_providers.py
Normal file
@ -0,0 +1,9 @@
|
||||
from src.config.settings import settings
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from google.adk.memory import InMemoryMemoryService
|
||||
|
||||
# Initialize service instances
|
||||
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
|
||||
artifacts_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
@ -66,7 +66,7 @@ def get_session_by_id(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid session ID. Expected format: app_name_user_id",
|
||||
)
|
||||
|
||||
|
||||
parts = session_id.split("_", 1)
|
||||
if len(parts) != 2:
|
||||
logger.error(f"Invalid session ID format: {session_id}")
|
||||
@ -74,22 +74,22 @@ def get_session_by_id(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid session ID format. Expected format: app_name_user_id",
|
||||
)
|
||||
|
||||
|
||||
user_id, app_name = parts
|
||||
|
||||
|
||||
session = session_service.get_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
if session is None:
|
||||
logger.error(f"Session not found: {session_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Session not found: {session_id}",
|
||||
)
|
||||
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for session {session_id}: {str(e)}")
|
||||
@ -106,7 +106,7 @@ def delete_session(session_service: DatabaseSessionService, session_id: str) ->
|
||||
try:
|
||||
session = get_session_by_id(session_service, session_id)
|
||||
# If we get here, the session exists (get_session_by_id already validates)
|
||||
|
||||
|
||||
session_service.delete_session(
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
@ -131,10 +131,10 @@ def get_session_events(
|
||||
try:
|
||||
session = get_session_by_id(session_service, session_id)
|
||||
# If we get here, the session exists (get_session_by_id already validates)
|
||||
|
||||
if not hasattr(session, 'events') or session.events is None:
|
||||
|
||||
if not hasattr(session, "events") or session.events is None:
|
||||
return []
|
||||
|
||||
|
||||
return session.events
|
||||
except HTTPException:
|
||||
# Passes HTTP exceptions from get_session_by_id
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
|
||||
"""Search for a tool by ID"""
|
||||
try:
|
||||
@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
|
||||
logger.error(f"Error searching for tool {tool_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for tool"
|
||||
detail="Error searching for tool",
|
||||
)
|
||||
|
||||
|
||||
def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
|
||||
"""Search for all tools with pagination"""
|
||||
try:
|
||||
@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
|
||||
logger.error(f"Error searching for tools: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error searching for tools"
|
||||
detail="Error searching for tools",
|
||||
)
|
||||
|
||||
|
||||
def create_tool(db: Session, tool: ToolCreate) -> Tool:
|
||||
"""Creates a new tool"""
|
||||
try:
|
||||
@ -49,19 +52,20 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:
|
||||
logger.error(f"Error creating tool: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error creating tool"
|
||||
detail="Error creating tool",
|
||||
)
|
||||
|
||||
|
||||
def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
|
||||
"""Updates an existing tool"""
|
||||
try:
|
||||
db_tool = get_tool(db, tool_id)
|
||||
if not db_tool:
|
||||
return None
|
||||
|
||||
|
||||
for key, value in tool.model_dump().items():
|
||||
setattr(db_tool, key, value)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_tool)
|
||||
logger.info(f"Tool updated successfully: {tool_id}")
|
||||
@ -71,16 +75,17 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T
|
||||
logger.error(f"Error updating tool {tool_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error updating tool"
|
||||
detail="Error updating tool",
|
||||
)
|
||||
|
||||
|
||||
def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
|
||||
"""Remove a tool"""
|
||||
try:
|
||||
db_tool = get_tool(db, tool_id)
|
||||
if not db_tool:
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_tool)
|
||||
db.commit()
|
||||
logger.info(f"Tool removed successfully: {tool_id}")
|
||||
@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
|
||||
logger.error(f"Error removing tool {tool_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error removing tool"
|
||||
)
|
||||
detail="Error removing tool",
|
||||
)
|
||||
|
@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.models.models import User, Client
|
||||
from src.schemas.user import UserCreate
|
||||
from src.utils.security import get_password_hash, verify_password, generate_token
|
||||
from src.services.email_service import send_verification_email, send_password_reset_email
|
||||
from src.services.email_service import (
|
||||
send_verification_email,
|
||||
send_password_reset_email,
|
||||
)
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
import logging
|
||||
@ -11,16 +14,22 @@ from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]:
|
||||
|
||||
def create_user(
|
||||
db: Session,
|
||||
user_data: UserCreate,
|
||||
is_admin: bool = False,
|
||||
client_id: Optional[uuid.UUID] = None,
|
||||
) -> Tuple[Optional[User], str]:
|
||||
"""
|
||||
Creates a new user in the system
|
||||
|
||||
Creates a new user in the system
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_data: User data to be created
|
||||
is_admin: If the user is an administrator
|
||||
client_id: Associated client ID (optional, a new one will be created if not provided)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
|
||||
"""
|
||||
@ -28,17 +37,19 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
|
||||
# Check if email already exists
|
||||
db_user = db.query(User).filter(User.email == user_data.email).first()
|
||||
if db_user:
|
||||
logger.warning(f"Attempt to register with existing email: {user_data.email}")
|
||||
logger.warning(
|
||||
f"Attempt to register with existing email: {user_data.email}"
|
||||
)
|
||||
return None, "Email already registered"
|
||||
|
||||
|
||||
# Create verification token
|
||||
verification_token = generate_token()
|
||||
token_expiry = datetime.utcnow() + timedelta(hours=24)
|
||||
|
||||
|
||||
# Start transaction
|
||||
user = None
|
||||
local_client_id = client_id
|
||||
|
||||
|
||||
try:
|
||||
# If not admin and no client_id, create an associated client
|
||||
if not is_admin and local_client_id is None:
|
||||
@ -46,7 +57,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
|
||||
db.add(client)
|
||||
db.flush() # Get the client ID
|
||||
local_client_id = client.id
|
||||
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
@ -56,52 +67,56 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
|
||||
is_active=False, # Inactive until email is verified
|
||||
email_verified=False,
|
||||
verification_token=verification_token,
|
||||
verification_token_expiry=token_expiry
|
||||
verification_token_expiry=token_expiry,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
|
||||
|
||||
# Send verification email
|
||||
email_sent = send_verification_email(user.email, verification_token)
|
||||
if not email_sent:
|
||||
logger.error(f"Failed to send verification email to {user.email}")
|
||||
# We don't do rollback here, we just log the error
|
||||
|
||||
|
||||
logger.info(f"User created successfully: {user.email}")
|
||||
return user, "User created successfully. Check your email to activate your account."
|
||||
|
||||
return (
|
||||
user,
|
||||
"User created successfully. Check your email to activate your account.",
|
||||
)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating user: {str(e)}")
|
||||
return None, f"Error creating user: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating user: {str(e)}")
|
||||
return None, f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
def verify_email(db: Session, token: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Verify the user's email using the provided token
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: Verification token
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: Tuple with verification status and message
|
||||
"""
|
||||
try:
|
||||
# Search for user by token
|
||||
user = db.query(User).filter(User.verification_token == token).first()
|
||||
|
||||
|
||||
if not user:
|
||||
logger.warning(f"Attempt to verify with invalid token: {token}")
|
||||
return False, "Invalid verification token"
|
||||
|
||||
|
||||
# Check if the token has expired
|
||||
now = datetime.utcnow()
|
||||
expiry = user.verification_token_expiry
|
||||
|
||||
|
||||
# Ensure both dates are of the same type (aware or naive)
|
||||
if expiry.tzinfo is not None and now.tzinfo is None:
|
||||
# If expiry has timezone and now doesn't, convert now to have timezone
|
||||
@ -109,180 +124,201 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
|
||||
elif now.tzinfo is not None and expiry.tzinfo is None:
|
||||
# If now has timezone and expiry doesn't, convert expiry to have timezone
|
||||
expiry = expiry.replace(tzinfo=now.tzinfo)
|
||||
|
||||
|
||||
if expiry < now:
|
||||
logger.warning(f"Attempt to verify with expired token for user: {user.email}")
|
||||
logger.warning(
|
||||
f"Attempt to verify with expired token for user: {user.email}"
|
||||
)
|
||||
return False, "Verification token expired"
|
||||
|
||||
|
||||
# Update user
|
||||
user.email_verified = True
|
||||
user.is_active = True
|
||||
user.verification_token = None
|
||||
user.verification_token_expiry = None
|
||||
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Email verified successfully for user: {user.email}")
|
||||
return True, "Email verified successfully. Your account is active."
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error verifying email: {str(e)}")
|
||||
return False, f"Error verifying email: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error verifying email: {str(e)}")
|
||||
return False, f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Resend the verification email
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: Tuple with operation status and message
|
||||
"""
|
||||
try:
|
||||
# Search for user by email
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
|
||||
if not user:
|
||||
logger.warning(f"Attempt to resend verification email for non-existent email: {email}")
|
||||
logger.warning(
|
||||
f"Attempt to resend verification email for non-existent email: {email}"
|
||||
)
|
||||
return False, "Email not found"
|
||||
|
||||
|
||||
if user.email_verified:
|
||||
logger.info(f"Attempt to resend verification email for already verified email: {email}")
|
||||
logger.info(
|
||||
f"Attempt to resend verification email for already verified email: {email}"
|
||||
)
|
||||
return False, "Email already verified"
|
||||
|
||||
|
||||
# Generate new token
|
||||
verification_token = generate_token()
|
||||
token_expiry = datetime.utcnow() + timedelta(hours=24)
|
||||
|
||||
|
||||
# Update user
|
||||
user.verification_token = verification_token
|
||||
user.verification_token_expiry = token_expiry
|
||||
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
# Send email
|
||||
email_sent = send_verification_email(user.email, verification_token)
|
||||
if not email_sent:
|
||||
logger.error(f"Failed to resend verification email to {user.email}")
|
||||
return False, "Failed to send verification email"
|
||||
|
||||
|
||||
logger.info(f"Verification email resent successfully to: {user.email}")
|
||||
return True, "Verification email resent. Check your inbox."
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error resending verification: {str(e)}")
|
||||
return False, f"Error resending verification: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error resending verification: {str(e)}")
|
||||
return False, f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Initiates the password recovery process
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: Tuple with operation status and message
|
||||
"""
|
||||
try:
|
||||
# Search for user by email
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
|
||||
if not user:
|
||||
# For security, we don't inform if the email exists or not
|
||||
logger.info(f"Attempt to recover password for non-existent email: {email}")
|
||||
return True, "If the email is registered, you will receive instructions to reset your password."
|
||||
|
||||
return (
|
||||
True,
|
||||
"If the email is registered, you will receive instructions to reset your password.",
|
||||
)
|
||||
|
||||
# Generate reset token
|
||||
reset_token = generate_token()
|
||||
token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour
|
||||
|
||||
|
||||
# Update user
|
||||
user.password_reset_token = reset_token
|
||||
user.password_reset_expiry = token_expiry
|
||||
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
# Send email
|
||||
email_sent = send_password_reset_email(user.email, reset_token)
|
||||
if not email_sent:
|
||||
logger.error(f"Failed to send password reset email to {user.email}")
|
||||
return False, "Failed to send password reset email"
|
||||
|
||||
|
||||
logger.info(f"Password reset email sent successfully to: {user.email}")
|
||||
return True, "If the email is registered, you will receive instructions to reset your password."
|
||||
|
||||
return (
|
||||
True,
|
||||
"If the email is registered, you will receive instructions to reset your password.",
|
||||
)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error processing password recovery: {str(e)}")
|
||||
return False, f"Error processing password recovery: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error processing password recovery: {str(e)}")
|
||||
return False, f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Resets the user's password using the provided token
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: Password reset token
|
||||
new_password: New password
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: Tuple with operation status and message
|
||||
"""
|
||||
try:
|
||||
# Search for user by token
|
||||
user = db.query(User).filter(User.password_reset_token == token).first()
|
||||
|
||||
|
||||
if not user:
|
||||
logger.warning(f"Attempt to reset password with invalid token: {token}")
|
||||
return False, "Invalid password reset token"
|
||||
|
||||
|
||||
# Check if the token has expired
|
||||
if user.password_reset_expiry < datetime.utcnow():
|
||||
logger.warning(f"Attempt to reset password with expired token for user: {user.email}")
|
||||
logger.warning(
|
||||
f"Attempt to reset password with expired token for user: {user.email}"
|
||||
)
|
||||
return False, "Password reset token expired"
|
||||
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
user.password_reset_token = None
|
||||
user.password_reset_expiry = None
|
||||
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Password reset successfully for user: {user.email}")
|
||||
return True, "Password reset successfully. You can now login with your new password."
|
||||
|
||||
return (
|
||||
True,
|
||||
"Password reset successfully. You can now login with your new password.",
|
||||
)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error resetting password: {str(e)}")
|
||||
return False, f"Error resetting password: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error resetting password: {str(e)}")
|
||||
return False, f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
"""
|
||||
Searches for a user by email
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[User]: User found or None
|
||||
"""
|
||||
@ -292,15 +328,16 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
logger.error(f"Error searching for user by email: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticates a user with email and password
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[User]: Authenticated user or None
|
||||
"""
|
||||
@ -313,75 +350,78 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
|
||||
"""
|
||||
Lists the admin users
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
|
||||
|
||||
Returns:
|
||||
List[User]: List of admin users
|
||||
"""
|
||||
try:
|
||||
users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all()
|
||||
users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
|
||||
logger.info(f"List of admins: {len(users)} found")
|
||||
return users
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error listing admins: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error listing admins: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
|
||||
"""
|
||||
Creates a new admin user
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_data: User data to be created
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
|
||||
"""
|
||||
return create_user(db, user_data, is_admin=True)
|
||||
|
||||
|
||||
def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
|
||||
"""
|
||||
Deactivates a user (does not delete, only marks as inactive)
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: ID of the user to be deactivated
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: Tuple with operation status and message
|
||||
"""
|
||||
try:
|
||||
# Search for user by ID
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
|
||||
if not user:
|
||||
logger.warning(f"Attempt to deactivate non-existent user: {user_id}")
|
||||
return False, "User not found"
|
||||
|
||||
|
||||
# Deactivate user
|
||||
user.is_active = False
|
||||
|
||||
|
||||
db.commit()
|
||||
logger.info(f"User deactivated successfully: {user.email}")
|
||||
return True, "User deactivated successfully"
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating user: {str(e)}")
|
||||
return False, f"Error deactivating user: {str(e)}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deactivating user: {str(e)}")
|
||||
return False, f"Unexpected error: {str(e)}"
|
||||
return False, f"Unexpected error: {str(e)}"
|
||||
|
@ -3,23 +3,26 @@ import os
|
||||
import sys
|
||||
from src.config.settings import settings
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
"""Custom formatter for logs"""
|
||||
|
||||
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
reset = "\x1b[0m"
|
||||
|
||||
format_template = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
|
||||
|
||||
format_template = (
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
|
||||
)
|
||||
|
||||
FORMATS = {
|
||||
logging.DEBUG: grey + format_template + reset,
|
||||
logging.INFO: grey + format_template + reset,
|
||||
logging.WARNING: yellow + format_template + reset,
|
||||
logging.ERROR: red + format_template + reset,
|
||||
logging.CRITICAL: bold_red + format_template + reset
|
||||
logging.CRITICAL: bold_red + format_template + reset,
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
@ -27,33 +30,34 @@ class CustomFormatter(logging.Formatter):
|
||||
formatter = logging.Formatter(log_fmt)
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
def setup_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Configures a custom logger
|
||||
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
|
||||
Returns:
|
||||
logging.Logger: Logger configurado
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
|
||||
# Remove existing handlers to avoid duplication
|
||||
if logger.handlers:
|
||||
logger.handlers.clear()
|
||||
|
||||
|
||||
# Configure the logger level based on the environment variable or configuration
|
||||
log_level = getattr(logging, os.getenv("LOG_LEVEL", settings.LOG_LEVEL).upper())
|
||||
logger.setLevel(log_level)
|
||||
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(CustomFormatter())
|
||||
console_handler.setLevel(log_level)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
# Prevent logs from being propagated to the root logger
|
||||
logger.propagate = False
|
||||
|
||||
return logger
|
||||
|
||||
return logger
|
||||
|
@ -11,41 +11,44 @@ from dataclasses import dataclass
|
||||
logger = logging.getLogger(__name__)

# Compatibility shim: recent bcrypt releases removed the ``__about__``
# attribute that passlib still reads for version detection, which would
# otherwise raise/log an error on import.
if not hasattr(bcrypt, "__about__"):

    @dataclass
    class BcryptAbout:
        # Mirror the single attribute passlib expects.
        __version__: str = getattr(bcrypt, "__version__")

    setattr(bcrypt, "__about__", BcryptAbout())

# Shared context for password hashing; bcrypt is the only active scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
    """Hash *password* using the module's bcrypt CryptContext."""
    return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
    """
    Create a signed JWT from *data*.

    Args:
        data: Claims to embed in the token; copied, never mutated.
        expires_delta: Optional custom lifetime. When omitted, the
            lifetime defaults to ``settings.JWT_EXPIRATION_TIME`` minutes.

    Returns:
        str: The encoded JWT string.
    """
    to_encode = data.copy()
    # Single expiry computation (the diff had left a duplicated
    # assignment here from both sides of the hunk).
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
    return encoded_jwt
|
||||
|
||||
|
||||
def generate_token(length: int = 32) -> str:
    """
    Generate a cryptographically secure random token.

    Suitable for email verification or password-reset links: uses the
    ``secrets`` module rather than ``random``.

    Args:
        length: Number of characters in the token (default 32).

    Returns:
        str: Token composed of ASCII letters and digits.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Package initialization for tests
|
1
tests/api/__init__.py
Normal file
1
tests/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# API tests package
|
11
tests/api/test_root.py
Normal file
11
tests/api/test_root.py
Normal file
@ -0,0 +1,11 @@
|
||||
def test_read_root(client):
    """
    The root endpoint responds with 200 and the expected top-level keys.
    """
    response = client.get("/")
    assert response.status_code == 200
    payload = response.json()
    for key in ("message", "documentation", "version", "auth"):
        assert key in payload
|
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Services tests package
|
27
tests/services/test_auth_service.py
Normal file
27
tests/services/test_auth_service.py
Normal file
@ -0,0 +1,27 @@
|
||||
from src.services.auth_service import create_access_token
|
||||
from src.models.models import User
|
||||
import uuid
|
||||
|
||||
|
||||
def test_create_access_token():
    """
    An access token is produced for a valid user and is a non-empty string.
    """
    # Build a minimal in-memory user (no database involved).
    mock_user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        hashed_password="hashed_password",
        is_active=True,
        is_admin=False,
        name="Test User",
        client_id=uuid.uuid4(),
    )

    token = create_access_token(mock_user)

    # Basic sanity checks on the generated token.
    assert token is not None
    assert isinstance(token, str)
    assert len(token) > 0
|
Loading…
Reference in New Issue
Block a user